diff --git a/.gitattributes b/.gitattributes index e8fc248b..985fc46b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,7 +5,10 @@ $.dcm binary *.tar.gz binary *.png binary +*.jpeg binary +*.jpg binary *.ttf binary +*.pickle binary *.woff binary *.ipynb filter=nbstripout diff --git a/.gitignore b/.gitignore index 9d92a7bc..c5331e0e 100644 --- a/.gitignore +++ b/.gitignore @@ -140,9 +140,17 @@ platipy/*/tests/data testing/ converted/ **/data +**/working **/tcia **/nifti_output services/*/*.db # Don't include html docs in repo -docs/site/ \ No newline at end of file +docs/site/ + +*.npy +*.nii.gz +valid*.png +ged*.png + +test_prob*/ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 7f7aa386..68118e59 100644 --- a/.pylintrc +++ b/.pylintrc @@ -139,10 +139,11 @@ disable=print-statement, deprecated-sys-function, exception-escape, comprehension-escape, - C0330, - C0114, - W0102, - W0105 + bad-continuation, + missing-module-docstring, + # pointless-string-statement, + dangerous-default-value, + arguments-differ # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option @@ -443,7 +444,7 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. -generated-members= +generated-members=torch.*,pytorch_lightning.* # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). diff --git a/examples/experimental/nnunet_service.ipynb b/examples/experimental/nnunet_service.ipynb index 1f8bbc71..b7116437 100644 --- a/examples/experimental/nnunet_service.ipynb +++ b/examples/experimental/nnunet_service.ipynb @@ -26,19 +26,24 @@ "sys.path.append(\"../../..\")\n", "\n", "import os\n", + "from pathlib import Path\n", + "\n", + "import SimpleITK as sitk \n", + "import time\n", "\n", "from platipy.backend.client import PlatiPyClient\n", "from platipy.imaging.tests.data import get_lung_nifti\n", + "from platipy.imaging import ImageVisualiser\n", + "from platipy.imaging.label.utils import get_com\n", "\n", "from loguru import logger\n", "\n", "host = \"127.0.0.1\" # Set the host name or IP of the server running the service here\n", - "host = \"10.55.72.183\"\n", "port = 8001 # Set the port the service was configured to run on here\n", "\n", "api_key = 'XXX' # Put API key here\n", - "api_key = \"fc1858e6-4432-47a4-b3b6-6df0ff652c38\"\n", - "algorithm_name = \"nnUNet Segmentation\" # The name of the algorithm, in this case it should be left as is\n", + "\n", + "algorithm_name = \"nnUNet Service\" # The name of the algorithm, in this case it should be left as is\n", "\n", "log_level = \"INFO\" # Choose an appropriate level of logging output: \"DEBUG\" or \"INFO\"\n", "\n", @@ -121,9 +126,9 @@ "metadata": {}, "outputs": [], "source": [ - "pat_id = list(images.keys())[0]\n", - "ct_file = os.path.join(images[pat_id], \"CT.nii.gz\")\n", - "data_object = client.add_data_object(dataset, file_path=ct_file)" + "images = [i for i in lung_data.glob(\"*/IMAGES/*.nii.gz\")]\n", + "ct_image = str(images[0])\n", + "data_object = client.add_data_object(dataset, file_path=ct_image)" ] }, { @@ -159,38 +164,11 @@ "metadata": {}, "outputs": [], "source": [ - "images" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": 
[], - "source": [ - "atlas_cases = list(images.keys())[1:]\n", - "atlas_path = os.path.dirname(images[atlas_cases[0]])\n", - "\n", "settings = client.get_default_settings()\n", - "\n", - "# Atlas settings\n", - "settings[\"atlasSettings\"][\"atlasPath\"] = atlas_path\n", - "settings[\"atlasSettings\"][\"atlasStructures\"] = [\"Heart\",\"Lung_L\",\"Lung_R\"]\n", - "settings[\"atlasSettings\"][\"atlasIdList\"] = atlas_cases\n", - "settings[\"atlasSettings\"][\"atlasImageFormat\"] = '{0}/CT.nii.gz'\n", - "settings[\"atlasSettings\"][\"atlasLabelFormat\"] = '{0}/Struct_{1}.nii.gz' \n", - "\n", - "# Run the DIR a bit more than default\n", - "settings['deformableSettings']['iterationStaging'] = [75,50,50]\n", - "\n", - "# Run the IAR using the heart\n", - "settings[\"IARSettings\"][\"referenceStructure\"] = 'Lung_L' \n", - "\n", - "# Set the threshold\n", - "settings['labelFusionSettings'][\"optimalThreshold\"] = {\"Heart\":0.5, \"Lung_L\": 0.5, \"Lung_R\": 0.5}\n", - "\n", - "# No vessels\n", - "settings['vesselSpliningSettings']['vesselNameList'] = []" + "settings['task'] = \"Task200_ClinicalHeart\"\n", + "settings['config'] = \"3d_lowres\"\n", + "settings['trainer'] = \"nnUNetTrainerHeart\"\n", + "settings['clean_sup_slices'] = True" ] }, { @@ -210,8 +188,13 @@ }, "outputs": [], "source": [ + "start = time.time()\n", + "\n", "for status in client.run_algorithm(dataset, config=settings):\n", - " print('.', end='')" + " print('.', end='')\n", + "\n", + "end = time.time()\n", + "print(f\"Took {end - start:.1f} seconds\")" ] }, { @@ -229,10 +212,30 @@ "metadata": {}, "outputs": [], "source": [ - "output_directory = os.path.join(\".\", \"results\", pat_id)\n", + "output_directory = os.path.join(\".\", \"results\")\n", "client.download_output_objects(dataset, output_path=output_directory)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Display the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "heart = sitk.ReadImage(str([s for s in Path(output_directory).glob(\"*\")][0]))\n", + "\n", + "vis = ImageVisualiser(sitk.ReadImage(str(ct_image)), cut=get_com(heart))\n", + "vis.add_contour({\"Heart\": heart})\n", + "fig=vis.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -246,7 +249,8 @@ "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90" }, "kernelspec": { - "display_name": "Python 3.6.9 64-bit", + "display_name": "Python 3", + "language": "python", "name": "python3" }, "language_info": { @@ -259,9 +263,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.8.8" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/platipy/imaging/cnn/__init__.py b/platipy/imaging/cnn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/platipy/imaging/cnn/dataload.py b/platipy/imaging/cnn/dataload.py new file mode 100644 index 00000000..bfcb4fab --- /dev/null +++ b/platipy/imaging/cnn/dataload.py @@ -0,0 +1,478 @@ +import random +import math +import logging +from pathlib import Path + +import torch + +import pytorch_lightning as pl + +from platipy.imaging.cnn.dataset import NiftiDataset +from platipy.imaging.cnn.sampler import ObserverSampler + +logger = logging.getLogger(__name__) + + +class UNetDataModule(pl.LightningDataModule): + """PyTorch data module to training UNets""" + + def __init__( + self, + data_dir: str = "./data", + data_add_dirs: list = [], + 
augmented_dir: str = None, + augmented_add_dirs: list = [], + working_dir: str = "./working", + structures=["a", "b", "c"], + observers=["0", "1", "2", "3", "4"], + observers_add=[], + case_glob="images/*.nii.gz", + image_glob="images/{case}.nii.gz", + label_glob="labels/{case}_{structure}_*.nii.gz", + label_add_glob="labels/{case}_{structure}.nii.gz", + context_map_glob="context_maps/{case}.nii.gz", + augmented_case_glob="{case}/*", + augmented_image_glob="images/{augmented_case}.nii.gz", + augmented_label_glob="labels/{augmented_case}_{structure}_*.nii.gz", + augmented_label_add_glob="labels/{augmented_case}_{structure}_*.nii.gz", + augmented_context_map_glob="context_maps/{case}_{augmented_case}.nii.gz", + augment_on_fly=True, + fold=0, + k_folds=5, + batch_size=5, + num_workers=4, + crop_to_grid_size_xy=128, + intensity_scaling="window", + intensity_window=[-500, 500], + num_observers=5, + spacing=[1, 1, 1], + contour_mask_kernel=3, + crop_using_localise_model=None, + localise_voxel_grid_size=[100, 100, 100], + validation_sampler="observer", # observer or batch + input_channels=1, + ndims=2, + **kwargs, + ): + super().__init__() + self.data_dir = Path(data_dir) + self.data_add_dirs = [Path(p) for p in data_add_dirs] + self.augmented_dir = augmented_dir + self.augmented_add_dirs = augmented_add_dirs + self.working_dir = Path(working_dir) + + self.case_glob = case_glob + self.image_glob = image_glob + self.label_glob = label_glob + self.label_add_glob = label_add_glob + self.context_map_glob = context_map_glob + + self.augmented_case_glob = augmented_case_glob + self.augmented_image_glob = augmented_image_glob + self.augmented_label_glob = augmented_label_glob + self.augmented_label_add_glob = augmented_label_add_glob + self.augmented_context_map_glob = augmented_context_map_glob + + self.augment_on_fly = augment_on_fly + self.fold = fold + self.k_folds = k_folds + + self.train_cases = [] + self.validation_cases = [] + self.test_cases = [] + + self.batch_size = batch_size + self.num_workers = num_workers + self.crop_to_grid_size_xy = crop_to_grid_size_xy + self.num_observers = num_observers + self.spacing = spacing + self.intensity_scaling = intensity_scaling + self.intensity_window = intensity_window + self.contour_mask_kernel = contour_mask_kernel + self.structures = structures + self.observers = observers + self.observers_add = observers_add + + self.crop_using_localise_model = crop_using_localise_model + self.localise_voxel_grid_size = localise_voxel_grid_size + + self.training_set = None + self.validation_set = None + self.test_set = None + self.validation_sampler = validation_sampler + + self.validation_data = [] + self.test_data = [] + + self.input_channels = input_channels + self.ndims = ndims + + print(f"Training fold {self.fold}") + + @staticmethod + def add_model_specific_args(parent_parser): + """Add arguments used for Data module""" + parser = parent_parser.add_argument_group("Data Loader") + parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--data_add_dirs", nargs="+", type=str, default=[]) + parser.add_argument("--augmented_dir", type=str, default=None) + parser.add_argument("--augmented_add_dirs", nargs="+", type=str, default=[]) + parser.add_argument("--augment_on_fly", type=bool, default=True) + parser.add_argument("--fold", type=int, default=0) + parser.add_argument("--k_folds", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=5) + parser.add_argument("--num_workers", type=int, default=4) + 
parser.add_argument( + "--structures", nargs="+", type=str, default=["a", "b", "c"] + ) + parser.add_argument( + "--observers", nargs="+", type=str, default=["0", "1", "2", "3", "4"] + ) + parser.add_argument("--observers_add", nargs="+", type=str, default=[]) + parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") + parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") + parser.add_argument( + "--label_glob", + type=str, + default="labels/{case}_{structure}_{observer}.nii.gz", + ) + parser.add_argument( + "--label_add_glob", type=str, default="labels/{case}_{structure}.nii.gz" + ) + parser.add_argument("--context_map_glob", type=str, default=None) + parser.add_argument("--augmented_case_glob", type=str, default=None) + parser.add_argument("--augmented_image_glob", type=str, default=None) + parser.add_argument("--augmented_label_glob", type=str, default=None) + parser.add_argument("--augmented_label_add_glob", type=str, default=None) + parser.add_argument("--augmented_context_map_glob", type=str, default=None) + parser.add_argument("--crop_to_grid_size_xy", type=int, default=128) + parser.add_argument("--intensity_scaling", type=str, default="window") + parser.add_argument( + "--intensity_window", nargs="+", type=int, default=[-500, 500] + ) + parser.add_argument("--contour_mask_kernel", type=int, default=5) + parser.add_argument("--crop_using_localise_model", type=str, default=None) + parser.add_argument( + "--localise_voxel_grid_size", nargs="+", type=int, default=[100, 100, 100] + ) + parser.add_argument("--ndims", type=int, default=2) + + return parent_parser + + def setup(self, stage=None): + cases = [ + p.name.replace(".nii.gz", "") + for p in self.data_dir.glob(self.case_glob) + if not p.name.startswith(".") + ] + cases.sort() + random.shuffle(cases) # will be consistent for same value of 'seed everything' + cases_per_fold = math.ceil(len(cases) / self.k_folds) + + for f in range(self.k_folds): + if self.fold == f: + val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + + if len(val_test_cases) == 1: + self.validation_cases = val_test_cases + else: + self.validation_cases = val_test_cases[ + : int(len(val_test_cases) / 2) + ] + self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] + else: + self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] + + print(f"Training cases: {self.train_cases}") + print(f"Validation cases: {self.validation_cases}") + print(f"Testing cases: {self.test_cases}") + + train_data = [ + { + "id": case, + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "context_map": None + if self.context_map_glob is None + else self.data_dir.joinpath(self.context_map_glob.format(case=case)), + "observers": { + observer: { + structure: self.data_dir.joinpath( + self.label_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers + }, + } + for case in self.train_cases + ] + + # If a directory with augmented data is specified, use that for training as well + if self.augmented_dir is not None: + for case in self.train_cases: + case_aug_dir = Path(self.augmented_dir.format(case=case)) + augmented_cases = [ + p.name.replace(".nii.gz", "") + for p in case_aug_dir.glob( + self.augmented_case_glob.format(case=case) + ) + if not p.name.startswith(".") + ] + + train_data += [ + { + "id": f"{case}_{augmented_case}", + "image": case_aug_dir.joinpath( + self.augmented_image_glob.format( + 
case=case, augmented_case=augmented_case + ) + ), + "context_map": None + if self.augmented_context_map_glob is None + else case_aug_dir.joinpath( + self.augmented_context_map_glob.format( + case=case, augmented_case=augmented_case + ) + ), + "observers": { + observer: { + structure: case_aug_dir.joinpath( + self.augmented_label_glob.format( + case=case, + augmented_case=augmented_case, + structure=structure, + observer=observer, + ) + ) + for structure in self.structures + } + for observer in self.observers + }, + } + for augmented_case in augmented_cases + ] + + # If observers_add is empty then just add one dummy observer since they are not using + # Multi observer data here + if len(self.observers_add) == 0: + self.observers_add = ["X"] + + # Add in the addtional cases, these are only use for training and may only have 1 observer + for data_add_dir in self.data_add_dirs: + self.add_train_cases = [] + cases = [ + p.name.replace(".nii.gz", "") + for p in data_add_dir.glob(self.case_glob) + if not p.name.startswith(".") + ] + self.add_train_cases += cases + train_data += [ + { + "id": case, + "image": data_add_dir.joinpath(self.image_glob.format(case=case)), + "context_map": None + if self.context_map_glob is None + else data_add_dir.joinpath(self.context_map_glob.format(case=case)), + "observers": { + observer: { + structure: data_add_dir.joinpath( + self.label_add_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers_add + }, + } + for case in cases + ] + + for case in cases: + case_aug_dir = None + for aug_add_dir in self.augmented_add_dirs: + if Path(aug_add_dir.format(case=case)).exists(): + case_aug_dir = Path(aug_add_dir.format(case=case)) + else: + print(f"No dir {Path(aug_add_dir.format(case=case))}") + + if case_aug_dir is None: + continue + + augmented_cases = [ + p.name.replace(".nii.gz", "") + for p in case_aug_dir.glob( + self.augmented_case_glob.format(case=case) + ) + if not p.name.startswith(".") + ] + print(augmented_cases) + + train_data += [ + { + "id": f"{case}_{augmented_case}", + "image": case_aug_dir.joinpath( + self.augmented_image_glob.format( + case=case, augmented_case=augmented_case + ) + ), + "context_map": None + if self.augmented_context_map_glob is None + else case_aug_dir.joinpath( + self.augmented_context_map_glob.format( + case=case, augmented_case=augmented_case + ) + ), + "observers": { + observer: { + structure: case_aug_dir.joinpath( + self.augmented_label_add_glob.format( + case=case, + augmented_case=augmented_case, + structure=structure, + observer=observer, + ) + ) + for structure in self.structures + } + for observer in self.observers_add + }, + } + for augmented_case in augmented_cases + ] + print(train_data) + print(len(train_data)) + + self.validation_data = [ + { + "id": case, + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "context_map": None + if self.context_map_glob is None + else self.data_dir.joinpath(self.context_map_glob.format(case=case)), + "observers": { + observer: { + structure: self.data_dir.joinpath( + self.label_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers + }, + } + for case in self.validation_cases + ] + print(self.validation_data) + + self.test_data = [ + { + "id": case, + "image": self.data_dir.joinpath(self.image_glob.format(case=case)), + "context_map": None + if self.context_map_glob is None + else 
self.data_dir.joinpath(self.context_map_glob.format(case=case)), + "observers": { + observer: { + structure: self.data_dir.joinpath( + self.label_glob.format( + case=case, structure=structure, observer=observer + ) + ) + for structure in self.structures + } + for observer in self.observers + }, + } + for case in self.test_cases + ] + + crop_to_grid_size = None + localise_model_path = None + if self.crop_using_localise_model: + localise_model_path = Path( + self.crop_using_localise_model.format(fold=self.fold) + ) + if localise_model_path.is_dir(): + localise_model_path = next(localise_model_path.glob("*.ckpt")) + + logger.info(f"Using localise model: {localise_model_path}") + crop_to_grid_size = self.localise_voxel_grid_size + else: + crop_to_grid_size = self.crop_to_grid_size_xy + + augment_on_fly = self.augment_on_fly + + use_context_map = False + if self.input_channels > 1: + use_context_map = True + + self.training_set = NiftiDataset( + train_data, + self.working_dir, + augment_on_fly=augment_on_fly, + spacing=self.spacing, + crop_to_grid_size=crop_to_grid_size, + crop_using_localise_model=localise_model_path, + contour_mask_kernel=self.contour_mask_kernel, + intensity_scaling=self.intensity_scaling, + intensity_window=self.intensity_window, + ndims=self.ndims, + use_context_map=use_context_map, + ) + self.validation_set = NiftiDataset( + self.validation_data, + self.working_dir, + augment_on_fly=False, + spacing=self.spacing, + crop_to_grid_size=crop_to_grid_size, + crop_using_localise_model=localise_model_path, + contour_mask_kernel=self.contour_mask_kernel, + intensity_scaling=self.intensity_scaling, + intensity_window=self.intensity_window, + ndims=self.ndims, + use_context_map=use_context_map, + ) + self.test_set = NiftiDataset( + self.test_data, + self.working_dir, + augment_on_fly=False, + spacing=self.spacing, + crop_to_grid_size=crop_to_grid_size, + crop_using_localise_model=localise_model_path, + contour_mask_kernel=self.contour_mask_kernel, + intensity_scaling=self.intensity_scaling, + intensity_window=self.intensity_window, + ndims=self.ndims, + use_context_map=use_context_map, + ) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + if self.validation_sampler == "observer": + return torch.utils.data.DataLoader( + self.validation_set, + batch_sampler=torch.utils.data.BatchSampler( + ObserverSampler(self.validation_set, self.num_observers), + batch_size=self.num_observers, + drop_last=False, + ), + num_workers=self.num_workers, + ) + else: + return torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py new file mode 100644 index 00000000..ccfd211b --- /dev/null +++ b/platipy/imaging/cnn/dataset.py @@ -0,0 +1,694 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import logging
+import math
+import random
+from pathlib import Path
+
+import numpy as np
+
+import torch
+
+import SimpleITK as sitk
+
+from imgaug import augmenters as iaa
+from imgaug.augmentables.segmaps import SegmentationMapsOnImage
+
+from scipy.ndimage import affine_transform, gaussian_filter, median_filter
+
+from platipy.imaging.cnn.utils import (
+    preprocess_image,
+    resample_mask_to_image,
+    get_contour_mask,
+)
+from platipy.imaging.label.utils import get_union_mask, get_intersection_mask
+from platipy.imaging.cnn.localise_net import LocaliseUNet
+from platipy.imaging.utils.crop import label_to_roi, crop_to_roi
+
+logger = logging.getLogger(__name__)
+
+
+class GaussianNoise:
+    """Add Gaussian noise to the image with a given probability."""
+
+    def __init__(self, mu=0.0, sigma=0.0, probability=1.0):
+        self.mu = mu
+        self.sigma = sigma
+        self.probability = probability
+
+        if not hasattr(self.mu, "__iter__"):
+            self.mu = (self.mu,) * 2
+
+        if not hasattr(self.sigma, "__iter__"):
+            self.sigma = (self.sigma,) * 2
+
+    def apply(self, img, context_map, masks=[]):
+        if random.random() > self.probability:
+            # Don't augment this time
+            return img, context_map, masks
+
+        mean = random.uniform(self.mu[0], self.mu[1])
+        sigma = random.uniform(self.sigma[0], self.sigma[1])
+
+        gaussian = np.random.normal(mean, sigma, img.shape)
+        return img + gaussian, context_map, masks
+
+
+class GaussianBlur:
+    """Apply a Gaussian blur to the image with a given probability."""
+
+    def __init__(self, sigma=0.0, probability=1.0):
+        self.sigma = sigma
+        self.probability = probability
+
+        if not hasattr(self.sigma, "__iter__"):
+            self.sigma = (self.sigma,) * 2
+
+    def apply(self, img, context_map, masks=[]):
+        if random.random() > self.probability:
+            # Don't augment this time
+            return img, context_map, masks
+
+        sigma = random.uniform(self.sigma[0], self.sigma[1])
+
+        return gaussian_filter(img, sigma=sigma), context_map, masks
+
+
+class MedianBlur:
+    """Apply a median blur to the image with a given probability."""
+
+    def __init__(self, size=1, probability=1.0):
+        self.size = size
+        self.probability = probability
+
+        if not hasattr(self.size, "__iter__"):
+            self.size = (self.size,) * 2
+
+    def apply(self, img, context_map, masks=[]):
+        if random.random() > self.probability:
+            # Don't augment this time
+            return img, context_map, masks
+
+        size = random.randint(self.size[0], self.size[1])
+
+        return median_filter(img, size=size), context_map, masks
+
+
+DIMS = ["ax", "cor", "sag"]
+
+
+class Affine:
+    """Apply a random affine transformation (scale, translation, rotation, shear) with a
+    given probability."""
+
+    def __init__(
+        self,
+        scale={"ax": 1.0, "cor": 1.0, "sag": 1.0},
+        translate_percent={"ax": 0.0, "cor": 0.0, "sag": 0.0},
+        rotate={"ax": 0.0, "cor": 0.0, "sag": 0.0},
+        shear={"ax": 0.0, "cor": 0.0, "sag": 0.0},
+        mode="constant",
+        cval=-1,
+        probability=1.0,
+    ):
+        # Note: mode and cval are currently unused; apply() hardcodes "mirror" for the
+        # image and "nearest" for the masks.
+        self.scale = scale
+        self.translate_percent = translate_percent
+        self.rotate = rotate
+        self.shear = shear
+        self.probability = probability
+
+        for d in self.rotate:
+            if not hasattr(self.rotate[d], "__iter__"):
+                self.rotate[d] = (self.rotate[d],) * 2
+
+        for d in self.scale:
+            if not hasattr(self.scale[d], "__iter__"):
+                self.scale[d] = (self.scale[d],) * 2
+
+        for d in self.translate_percent:
+            if not hasattr(self.translate_percent[d], "__iter__"):
+                self.translate_percent[d] = (self.translate_percent[d],) * 2
+
+        for d in self.shear:
+            if not hasattr(self.shear[d], "__iter__"):
+                self.shear[d] = (self.shear[d],) * 2
+
+    def get_rot(self, theta, d):
+        """Return the 4x4 rotation matrix for angle theta about axis d."""
+        if d == "ax":
+            return
np.matrix( + [ + [1, 0, 0, 0], + [0, math.cos(theta), -math.sin(theta), 0], + [0, math.sin(theta), math.cos(theta), 0], + [0, 0, 0, 1], + ] + ) + + if d == "cor": + return np.matrix( + [ + [math.cos(theta), 0, math.sin(theta), 0], + [0, 1, 0, 0], + [-math.sin(theta), 0, math.cos(theta), 0], + [0, 0, 0, 1], + ] + ) + + if d == "sag": + return np.matrix( + [ + [math.cos(theta), -math.sin(theta), 0, 0], + [math.sin(theta), math.cos(theta), 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ] + ) + + def get_shear(self, shear): + mat = np.identity(4) + mat[0, 1] = shear[1] + mat[0, 2] = shear[2] + mat[1, 0] = shear[0] + mat[1, 2] = shear[2] + mat[2, 0] = shear[0] + mat[2, 1] = shear[1] + + return mat + + def apply(self, img, context_map, masks=[]): + if random.random() > self.probability: + # Don't augment this time + return img, context_map, masks + + deg_to_rad = math.pi / 180 + + t_prerot = np.identity(4) + t_postrot = np.identity(4) + for i, d in enumerate(DIMS): + t_prerot[i, -1] = -img.shape[i] / 2 + t_postrot[i, -1] = img.shape[i] / 2 + + t = t_postrot + + for i, d in enumerate(DIMS): + t = t * self.get_rot( + random.uniform(self.rotate[d][0], self.rotate[d][1]) * deg_to_rad, d + ) + + for i, d in enumerate(DIMS): + scale = np.identity(4) + scale[i, i] = 1 / random.uniform(self.scale[d][0], self.scale[d][1]) + t = t * scale + + shear = [] + for i, d in enumerate(DIMS): + shear.append(random.uniform(self.shear[d][0], self.shear[d][1])) + + t = t * self.get_shear(shear) + + t = t * t_prerot + + for i, d in enumerate(DIMS): + trans = [p * img.shape[i] for p in self.translate_percent[d]] + translation = np.identity(4) + translation[i, -1] = random.uniform(trans[0], trans[1]) + t = t * translation + + augmented_image = affine_transform(img, t, mode="mirror") + augmented_context_map = None + if context_map is not None: + augmented_context_map = affine_transform(context_map, t, mode="mirror") + augmented_masks = [] + for mask in masks: + augmented_masks.append(affine_transform(mask, t, mode="nearest")) + + return augmented_image, augmented_context_map, augmented_masks + + +def crop_img_using_localise_model( + img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100], context_seg=None +): + """Crops an image using a LocaliseUNet + + Args: + img (SimpleITK.Image): The image to crop + localise_model (str|Path|LocaliseUNet): The LocaliseUNet or path to checkpoint of + LocaliseUNet. + spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1]. + crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to + [100,100,100]. + context_seg (sitk.Image, optional): Use this segmentation instead of localise model if + provided. Defaults to None. + + Returns: + SimpleITK.Image: The cropped image. 
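+
+    Example (an illustrative sketch; the checkpoint and image paths here are assumptions):
+
+        >>> import SimpleITK as sitk
+        >>> img = sitk.ReadImage("./data/case_01/CT.nii.gz")
+        >>> cropped = crop_img_using_localise_model(
+        ...     img,
+        ...     "./models/localise/fold_0.ckpt",
+        ...     spacing=[1, 1, 1],
+        ...     crop_to_grid_size=[100, 100, 100],
+        ... )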
+    """
+
+    if isinstance(localise_model, str):
+        localise_model = Path(localise_model)
+
+    if context_seg is not None:
+        localise_pred = context_seg
+    else:
+        if isinstance(localise_model, Path):
+            if localise_model.is_dir():
+                # Find the first actual model checkpoint in this directory
+                localise_model = next(localise_model.glob("*.ckpt"))
+
+            localise_model = LocaliseUNet.load_from_checkpoint(localise_model)
+
+        localise_model.eval()
+        localise_pred = localise_model.infer(img)
+
+    img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None)
+    localise_pred = resample_mask_to_image(img, localise_pred)
+    size, index = label_to_roi(localise_pred)
+
+    if not hasattr(crop_to_grid_size, "__iter__"):
+        crop_to_grid_size = (crop_to_grid_size,) * 3
+
+    index = [i - int((g - s) / 2) for i, s, g in zip(index, size, crop_to_grid_size)]
+    size = crop_to_grid_size
+    img_size = img.GetSize()
+    for i in range(3):
+        if index[i] + size[i] >= img_size[i]:
+            index[i] = img_size[i] - size[i] - 1
+        if index[i] < 0:
+            index[i] = 0
+
+    return crop_to_roi(img, size, index)
+
+
+def prepare_3d_transforms():
+    """Build the list of 3D augmentation transforms applied on the fly."""
+    affine_aug = Affine(
+        translate_percent={"ax": [-0.1, 0.1], "cor": [-0.1, 0.1], "sag": [-0.1, 0.1]},
+        rotate={"ax": [-10.0, 10.0], "cor": [-10.0, 10.0], "sag": [-10.0, 10.0]},
+        scale={"ax": [0.8, 1.2], "cor": [0.8, 1.2], "sag": [0.8, 1.2]},
+        shear={"ax": [0.0, 0.2], "cor": [0.0, 0.2], "sag": [0.0, 0.2]},
+        probability=0.5,
+    )
+    gaussian_blur = GaussianBlur(sigma=[0.0, 1.0], probability=0.33)
+    median_blur = MedianBlur(size=[1, 3], probability=0.5)
+    gaussian_noise = GaussianNoise(sigma=[0, 0.2], probability=0.5)
+
+    return [affine_aug, gaussian_blur, median_blur, gaussian_noise]
+
+
+def prepare_transforms():
+    """Build the imgaug augmentation sequence used for 2D slices."""
+    sometimes = lambda aug: iaa.Sometimes(0.5, aug)
+
+    seq = iaa.Sequential(
+        [
+            sometimes(
+                iaa.Affine(
+                    scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
+                    translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
+                    rotate=(-15, 15),
+                    shear=(-8, 8),
+                    cval=-1,
+                )
+            ),
+            # execute 0 to 2 of the following (less important) augmenters per image
+            iaa.SomeOf(
+                (0, 2),
+                [
+                    iaa.OneOf(
+                        [
+                            iaa.GaussianBlur((0, 1.5)),
+                            iaa.AverageBlur(k=(3, 5)),
+                        ]
+                    ),
+                    sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))),
+                ],
+                random_order=True,
+            ),
+            sometimes(iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.1))),
+        ],
+        random_order=True,
+    )
+
+    return seq
+
+
+class NiftiDataset(torch.utils.data.Dataset):
+    """PyTorch Dataset for processing Nifti data"""
+
+    def __init__(
+        self,
+        data,
+        working_dir,
+        augment_on_fly=True,
+        spacing=[1, 1, 1],
+        crop_using_localise_model=None,
+        crop_to_grid_size=128,
+        contour_mask_kernel=5,
+        combine_observers=None,
+        intensity_scaling="window",
+        intensity_window=[-500, 500],
+        use_context_map=False,
+        ndims=2,
+    ):
+        """Prepare a dataset from Nifti images/labels
+
+        Args:
+            data (list): List of dicts where each item contains the keys: "id", "image",
+                "context_map" and "observers". "image" and "context_map" hold paths to the
+                Nifti files, and "observers" maps each observer to a dict of structure name
+                to label path.
+            working_dir (str|path): Working directory in which to write prepared files.
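+            augment_on_fly (bool, optional): If True, samples are augmented as they are
+                fetched rather than ahead of time. Defaults to True.
+            spacing (list, optional): The image spacing (mm) to resample to.
+                Defaults to [1, 1, 1].
+            crop_using_localise_model (str|Path, optional): Checkpoint of a LocaliseUNet
+                used to crop around the localised structure. Defaults to None.
+            crop_to_grid_size (int|list, optional): The grid size to crop to.
+                Defaults to 128.
+            contour_mask_kernel (int, optional): Kernel size used when computing the contour
+                mask across observers. Defaults to 5.
+            combine_observers (str, optional): "union" or "intersection" to combine all
+                observers' labels into one; None keeps observers separate. Defaults to None.
+            intensity_scaling (str, optional): Intensity scaling mode (e.g. "window").
+                Defaults to "window".
+            intensity_window (list, optional): The lower and upper intensity window applied
+                when intensity_scaling is "window". Defaults to [-500, 500].
+            use_context_map (bool, optional): Whether a context map is loaded for each case.
+                Defaults to False.
+            ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2.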
+ """ + + self.data = data + self.transforms = None + self.ndims = ndims + if augment_on_fly: + if self.ndims == 2: + self.transforms = prepare_transforms() + else: + self.transforms = prepare_3d_transforms() + self.slices = [] + self.working_dir = Path(working_dir) + + self.img_dir = working_dir.joinpath("img") + self.label_dir = working_dir.joinpath("label") + self.contour_mask_dir = working_dir.joinpath("contour_mask") + self.context_map_dir = working_dir.joinpath("context_map") + + self.img_dir.mkdir(exist_ok=True, parents=True) + self.label_dir.mkdir(exist_ok=True, parents=True) + self.contour_mask_dir.mkdir(exist_ok=True, parents=True) + self.context_map_dir.mkdir(exist_ok=True, parents=True) + + for case in data: + case_id = case["id"] + img_path = str(case["image"]) + cmap_file = None + + if use_context_map: + context_map_path = str(case["context_map"]) + + existing_images = [i for i in self.img_dir.glob(f"{case_id}_*.npy")] + if len(existing_images) > 0: + logger.debug(f"Image for case already exist: {case_id}") + + for img_path in existing_images: + z_matches = re.findall(rf"{case_id}_([0-9]*)\.npy", img_path.name) + if len(z_matches) == 0: + continue + z_slice = int(z_matches[0]) + + img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") + assert img_file.exists() + + cmap_file = None + if use_context_map: + cmap_file = self.context_map_dir.joinpath( + f"{case_id}_{z_slice}.npy" + ) + assert cmap_file.exists() + + for obs in case["observers"]: + labels = [] + contour_mask_files = [] + for structure in case["observers"][obs]: + label_file = self.label_dir.joinpath( + f"{case_id}_{structure}_{obs}_{z_slice}.npy" + ) + if label_file.exists(): + labels.append(label_file) + else: + labels.append(None) + + contour_mask_file = self.contour_mask_dir.joinpath( + f"{case_id}_{structure}_{z_slice}.npy" + ) + assert contour_mask_file.exists() + contour_mask_files.append(contour_mask_file) + + for oo in case["observers"]: + self.slices.append( + { + "z": z_slice, + "image": img_file, + "labels": labels, + "complabels": [self.label_dir.joinpath( + f"{case_id}_{structure}_{oo}_{z_slice}.npy" + ) for structure in case["observers"][obs]], + "contour_masks": contour_mask_files, + "context_map": cmap_file, + "case": case_id, + "observer": obs, + } + ) + + continue + + logger.debug(f"Generating images for case: {case_id}") + img = sitk.ReadImage(img_path) + + if use_context_map: + context_map = sitk.ReadImage(context_map_path) + + if crop_using_localise_model: + img = crop_img_using_localise_model( + img, + crop_using_localise_model, + spacing=spacing, + crop_to_grid_size=crop_to_grid_size, + ) + else: + img = preprocess_image( + img, + spacing=spacing, + crop_to_grid_size_xy=crop_to_grid_size, + intensity_scaling=intensity_scaling, + intensity_window=intensity_window, + ) + + if use_context_map: + context_map = resample_mask_to_image(img, context_map) + + observers = {} + structure_names = [] + for obs in case["observers"]: + observers[obs] = {} + for structure in case["observers"][obs]: + if structure not in structure_names: + structure_names.append(structure) + structure_path = case["observers"][obs][structure] + + label = None + if structure_path.exists(): + label = sitk.ReadImage(str(structure_path)) + label = resample_mask_to_image(img, label) + observers[obs][structure] = label + + contour_masks = {} + for structure in structure_names: + contour_masks[structure] = get_contour_mask( + [ + observers[obs][structure] + for obs in case["observers"] + if observers[obs][structure] is not None 
+ ], + kernel=contour_mask_kernel, + ) + + if combine_observers: + updated_observers = {"": {}} + for structure in structure_names: + if combine_observers == "union": + updated_observers[""][structure] = [ + get_union_mask( + [ + observers[obs][structure] + for obs in case["observers"] + if observers[obs][structure] is not None + ] + ) + ] + elif combine_observers == "intersection": + updated_observers[""][structure] = [ + get_intersection_mask( + [ + observers[obs][structure] + for obs in case["observers"] + if observers[obs][structure] is not None + ] + ) + ] + else: + raise NotImplementedError( + "combine_observers should be 'union' or 'intersection'" + ) + + observers = updated_observers + + z_range = range(img.GetSize()[2]) + if ndims == 3: + z_range = range(1) + for z_slice in z_range: + # Save the image slice + if ndims == 2: + img_slice = img[:, :, z_slice] + + if use_context_map: + cmap_slice = context_map[:, :, z_slice] + else: + img_slice = img + if use_context_map: + cmap_slice = context_map + + img_file = self.img_dir.joinpath(f"{case_id}_{z_slice}.npy") + np.save(img_file, sitk.GetArrayFromImage(img_slice)) + + if use_context_map: + cmap_file = self.context_map_dir.joinpath( + f"{case_id}_{z_slice}.npy" + ) + np.save(cmap_file, sitk.GetArrayFromImage(cmap_slice)) + + # Save the contour mask slice + cmasks = [] + for structure in structure_names: + if ndims == 2: + contour_mask_slice = contour_masks[structure][:, :, z_slice] + else: + contour_mask_slice = contour_masks[structure] + contour_mask_file = self.contour_mask_dir.joinpath( + f"{case_id}_{structure}_{z_slice}.npy" + ) + np.save( + contour_mask_file, sitk.GetArrayFromImage(contour_mask_slice) + ) + cmasks.append(contour_mask_file) + + for obs in observers: + labels = [] + for structure in structure_names: + if observers[obs][structure] is None: + labels.append(None) + continue + if ndims == 2: + label_slice = observers[obs][structure][:, :, z_slice] + else: + label_slice = observers[obs][structure] + label_file = self.label_dir.joinpath( + f"{case_id}_{structure}_{obs}_{z_slice}.npy" + ) + np.save( + label_file, + sitk.GetArrayFromImage(label_slice).astype(np.int8), + ) + labels.append(label_file) + + # TODO allow enabling this + for oo in observers: + self.slices.append( + { + "z": z_slice, + "image": img_file, + "labels": labels, + "complabels": [self.label_dir.joinpath( + f"{case_id}_{structure}_{oo}_{z_slice}.npy" + ) for structure in structure_names], + "contour_masks": cmasks, + "context_map": cmap_file, + "case": case_id, + "observer": obs, + } + ) + + def __len__(self): + return len(self.slices) + + def __getitem__(self, index): + img = np.load(self.slices[index]["image"]) + labels = [ + np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) + for label_file in self.slices[index]["labels"] + ] + complabels = [ + np.load(label_file) if label_file else np.zeros(img.shape, dtype=np.ushort) + for label_file in self.slices[index]["complabels"] + ] + contour_masks = [ + np.load(contour_mask_file) + for contour_mask_file in self.slices[index]["contour_masks"] + ] + + context_map = torch.Tensor() + use_context = False + if self.slices[index]["context_map"] is not None: + use_context = True + context_map = np.load(self.slices[index]["context_map"]) + + if self.transforms: + masks = labels + complabels + contour_masks + if self.ndims == 2: + seg_arr = np.concatenate([np.expand_dims(m, 2) for m in masks], 2) + segmap = SegmentationMapsOnImage(seg_arr, shape=labels[0].shape) + img, seg = 
self.transforms(image=img, segmentation_maps=segmap)
+
+                # TODO Implement context map aug for 2D
+                if use_context:
+                    raise NotImplementedError(
+                        "Augmentation for context map in 2D not yet implemented!"
+                    )
+                # The segmentation map channels are ordered labels, then complabels, then
+                # contour masks, so unpack each group at the matching offset
+                for idx, _ in enumerate(labels):
+                    labels[idx] = seg.get_arr()[:, :, idx].squeeze()
+                for idx, _ in enumerate(complabels):
+                    complabels[idx] = seg.get_arr()[:, :, len(labels) + idx].squeeze()
+                for idx, _ in enumerate(contour_masks):
+                    contour_masks[idx] = seg.get_arr()[
+                        :, :, len(labels) + len(complabels) + idx
+                    ].squeeze()
+            else:
+                for aug in self.transforms:
+                    if use_context:
+                        img, context_map, masks = aug.apply(img, context_map, masks)
+                    else:
+                        img, _, masks = aug.apply(img, None, masks)
+                labels = masks[: len(labels)]
+                complabels = masks[len(labels) : len(complabels) + len(labels)]
+                contour_masks = masks[len(labels) + len(complabels) :]
+
+        img = torch.FloatTensor(img)
+        img = img.unsqueeze(0)
+
+        if context_map is not None:
+            context_map = torch.FloatTensor(context_map)
+            context_map = context_map.unsqueeze(0)
+
+        label = torch.FloatTensor(
+            np.concatenate([np.expand_dims(l, 0) for l in labels], 0).astype("int8")
+        )
+        complabel = torch.FloatTensor(
+            np.concatenate([np.expand_dims(l, 0) for l in complabels], 0).astype("int8")
+        )
+        contour_mask = torch.FloatTensor(
+            np.concatenate([np.expand_dims(l, 0) for l in contour_masks], 0).astype(
+                "int8"
+            )
+        )
+        contour_mask = contour_mask.max(axis=0).values.unsqueeze(0)
+        label_present = [lbl is not None for lbl in self.slices[index]["labels"]]
+
+        return (
+            img,
+            context_map,
+            label,
+            complabel,
+            contour_mask,
+            {
+                "case": str(self.slices[index]["case"]),
+                "observer": str(self.slices[index]["observer"]),
+                "label_present": label_present,
+                "z": self.slices[index]["z"],
+            },
+        )
diff --git a/platipy/imaging/cnn/hierarchical_prob_unet.py b/platipy/imaging/cnn/hierarchical_prob_unet.py
new file mode 100644
index 00000000..9679a049
--- /dev/null
+++ b/platipy/imaging/cnn/hierarchical_prob_unet.py
@@ -0,0 +1,947 @@
+# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This code is adapted from
+# https://github.com/deepmind/deepmind-research/tree/5cf55efe1f1748ebdd33cb69223b0df6bcc88e6a/hierarchical_probabilistic_unet
+# which is released under the Apache License 2.0
+
+# pylint: disable=invalid-name
+
+import torch
+
+from .unet import init_weights, init_zeros, conv_nd
+
+
+class ResBlock(torch.nn.Module):
+    """A residual block"""
+
+    def __init__(
+        self,
+        input_channels,
+        output_channels,
+        n_down_channels=None,
+        activation_fn=torch.nn.ReLU,
+        convs_per_block=2,
+        ndims=2,
+    ):
+        """Create a residual block
+
+        Args:
+            input_channels (int): The number of input channels to the block
+            output_channels (int): The number of output channels from the block
+            n_down_channels (int, optional): The number of intermediate channels within the block.
+                Defaults to the same as the number of output channels.
+            activation_fn (torch.nn.Module, optional): The activation function to apply. Defaults
+                to torch.nn.ReLU.
+            convs_per_block (int, optional): The number of convolutions to perform within the
+                block. Defaults to 2.
+            ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2.
+        """
+
+        super(ResBlock, self).__init__()
+
+        self._activation_fn = activation_fn()
+
+        # Set the number of intermediate channels that we compress to.
+        if n_down_channels is None:
+            n_down_channels = output_channels
+
+        layers = []
+        in_channels = input_channels
+        for c in range(convs_per_block):
+            layers.append(
+                conv_nd(
+                    ndims=ndims,
+                    in_channels=in_channels,
+                    out_channels=n_down_channels,
+                    kernel_size=3,
+                    padding=1,
+                )
+            )
+
+            if c < convs_per_block - 1:
+                layers.append(activation_fn())
+
+            in_channels = n_down_channels
+
+        if n_down_channels != output_channels:
+            resize_outgoing = conv_nd(
+                ndims=ndims,
+                in_channels=n_down_channels,
+                out_channels=output_channels,
+                kernel_size=1,
+                padding=0,
+            )
+            layers.append(resize_outgoing)
+
+        self._layers = torch.nn.Sequential(*layers)
+        # self._layers.apply(init_weights)
+
+        self._resize_skip = None
+
+        if input_channels != output_channels:
+            self._resize_skip = conv_nd(
+                ndims=ndims,
+                in_channels=input_channels,
+                out_channels=output_channels,
+                kernel_size=1,
+                padding=0,
+            )
+            # self._resize_skip.apply(init_weights)
+
+    def forward(self, input_features):
+        """Pass the input through the residual block."""
+
+        # Pre-activate the inputs.
+        skip = input_features
+        residual = self._activation_fn(input_features)
+
+        for layer in self._layers:
+            residual = layer(residual)
+
+        if self._resize_skip is not None:
+            skip = self._resize_skip(skip)
+
+        return skip + residual
+
+
+def resize_up(input_features, scale=2):
+    """Resize the input to upsample
+
+    Args:
+        input_features (torch.Tensor): The Tensor to upsize
+        scale (int, optional): The scale used to upsize. Defaults to 2.
+
+    Returns:
+        torch.Tensor: The upsized Tensor
+    """
+
+    input_shape = input_features.shape
+    size_x = input_shape[2]
+    size_y = input_shape[3]
+
+    new_size = [int(round(size_x * scale)), int(round(size_y * scale))]
+
+    if len(input_shape) == 5:
+        size_z = input_shape[4]
+        new_size = new_size + [int(round(size_z * scale))]
+
+    return torch.nn.functional.interpolate(input_features, size=new_size)
+
+
+def resize_down(input_features, scale=2):
+    """Resize the input to downsample
+
+    Args:
+        input_features (torch.Tensor): The Tensor to downsize
+        scale (int, optional): The scale used to downsize. Defaults to 2.
+
+    Returns:
+        torch.Tensor: The downsized Tensor
+    """
+    if input_features.ndim == 5:
+        return torch.nn.AvgPool3d(kernel_size=scale, stride=scale, padding=0)(input_features)
+    else:
+        return torch.nn.AvgPool2d(kernel_size=scale, stride=scale, padding=0)(input_features)
+
+
+class _HierarchicalCore(torch.nn.Module):
+    """A U-Net encoder-decoder with a full encoder and a truncated decoder.
+    The truncated decoder is interleaved with the hierarchical latent space and
+    has as many levels as there are levels in the hierarchy plus one additional
+    level.
+    """
+
+    def __init__(
+        self,
+        latent_dims,
+        input_channels,
+        channels_per_block,
+        down_channels_per_block=None,
+        activation_fn=torch.nn.ReLU,
+        convs_per_block=2,
+        blocks_per_level=1,
+        ndims=2,
+    ):
+        """Initializes a HierarchicalCore.
+
+        Args:
+            latent_dims (list): List of integers specifying the dimensions of the latents at
+                each scale. The length of the list indicates the number of U-Net
+                decoder scales that have latents.
+            input_channels (int): The number of input channels.
+ channels_per_block (list): A list of integers specifying the number of output + channels for each encoder block. + down_channels_per_block (list, optional): A list of integers specifying the number of + intermediate channels for each encoder block + or None. If None, the intermediate channels + are chosen equal to channels_per_block. + Defaults to None. + activation_fn (torch.nn.Module, optional): A callable activation function. Defaults to + torch.nn.ReLU. + convs_per_block (int, optional): An integer specifying the number of convolutional + layers. Defaults to 2. + blocks_per_level (int, optional): An integer specifying the number of residual blocks + per level. Defaults to 1. + ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2. + """ + + super(_HierarchicalCore, self).__init__() + + self._latent_dims = latent_dims + self._input_channels = input_channels + self._channels_per_block = channels_per_block + self._activation_fn = activation_fn + self._convs_per_block = convs_per_block + self._blocks_per_level = blocks_per_level + if down_channels_per_block is None: + self._down_channels_per_block = channels_per_block + else: + self._down_channels_per_block = down_channels_per_block + + num_levels = len(self._channels_per_block) + self._num_latent_levels = len(self._latent_dims) + + # Iterate the descending levels in the U-Net encoder. + self.encoder_layers = torch.nn.ModuleList() + in_channels = input_channels + for level in range(num_levels): + # Iterate the residual blocks in each level. + layer = [] + for _ in range(self._blocks_per_level): + layer.append( + ResBlock( + in_channels, + channels_per_block[level], + n_down_channels=self._down_channels_per_block[level], + activation_fn=self._activation_fn, + convs_per_block=self._convs_per_block, + ndims=ndims, + ) + ) + in_channels = channels_per_block[level] + + self.encoder_layers.append(torch.nn.Sequential(*layer)) + + # self.encoder_layers.apply(init_weights) + + # Iterate the ascending levels in the (truncated) U-Net decoder. + self.decoder_layers = torch.nn.ModuleList() + self._mu_logsigma_blocks = torch.nn.ModuleList() + + for level in range(self._num_latent_levels): + + latent_dim = latent_dims[level] + + mu_logsigma_block = conv_nd( + ndims=ndims, + in_channels=channels_per_block[::-1][level], + out_channels=2 * latent_dim, + kernel_size=1, + padding=0, + ) + + self._mu_logsigma_blocks.append(mu_logsigma_block) + + decoder_in_channels = ( + channels_per_block[::-1][level + 1] + channels_per_block[::-1][level] + ) + latent_dim + layer = [] + for _ in range(self._blocks_per_level): + layer.append( + ResBlock( + decoder_in_channels, + channels_per_block[::-1][level + 1], + n_down_channels=self._down_channels_per_block[::-1][level + 1], + activation_fn=self._activation_fn, + convs_per_block=self._convs_per_block, + ndims=ndims, + ) + ) + decoder_in_channels = channels_per_block[::-1][level + 1] + + self.decoder_layers.append(torch.nn.Sequential(*layer)) + + # self._mu_logsigma_blocks.apply(init_zeros) + # self.decoder_layers.apply(init_weights) + + def forward(self, inputs, mean=False, std_devs_from_mean=0.0, z_q=None): + """Forward pass to sample from the module as specified. + + Args: + inputs (torch.Tensor): A tensor of shape (b,c,h,w). When using the module as a prior + the `inputs` tensor should be a batch of images. When using it + as a posterior the tensor should be a (batched) concatentation + of images and segmentations. + mean (bool|list, optional): A boolean or a list of booleans. 
If a boolean, it specifies + whether or not to use the distributions' means in ALL + latent scales. If a list, each bool therein specifies + whether or not to use the scale's mean. If False, the + latents of the scale are sampled. Defaults to False. + std_devs_from_mean (float|list, optional): A float or list of floats describing how far + from the mean should be sampled. Only at + scales where mean is True. Defaults to 0. + z_q (list, optional): None or a list of tensors. If not None, z_q provides external + latents to be used instead of sampling them. This is used to + employ posterior latents in the prior during training. Therefore, + if z_q is not None, the value of `mean` is ignored. If z_q is + None, either the distributions mean is used (in case `mean` for + the respective scale is True) or else a sample from the + distribution is drawn. Defaults to None. + + Returns: + dict: A Dictionary holding the output feature map of the truncated U-Net decoder under + key 'decoder_features', a list of the U-Net encoder features produced at the end of + each encoder scale under key 'encoder_outputs', a list of the predicted distributions + at each scale under key 'distributions', a list of the used latents at each scale under + the key 'used_latents'. + """ + + encoder_features = inputs + encoder_outputs = [] + num_levels = len(self._channels_per_block) + num_latent_levels = len(self._latent_dims) + + if isinstance(mean, bool): + mean = [mean] * self._num_latent_levels + + if isinstance(std_devs_from_mean, int): + std_devs_from_mean = float(std_devs_from_mean) + + if isinstance(std_devs_from_mean, float): + std_devs_from_mean = [std_devs_from_mean] * self._num_latent_levels + + distributions = [] + used_latents = [] + + # Iterate the descending levels in the U-Net encoder. + for level, encoder_layer in enumerate(self.encoder_layers): + encoder_features = encoder_layer(encoder_features) + encoder_outputs.append(encoder_features) + if not level == num_levels - 1: + encoder_features = resize_down(encoder_features, scale=2) + + # Iterate the ascending levels in the (truncated) U-Net decoder. + decoder_features = encoder_outputs[-1] + for level in range(num_latent_levels): + + # Predict a Gaussian distribution for each pixel in the feature map. + latent_dim = self._latent_dims[level] + mu_logsigma = self._mu_logsigma_blocks[level](decoder_features) + + mu = mu_logsigma[:, :latent_dim].clamp(-1000, 1000) + log_sigma = mu_logsigma[:, latent_dim:].clamp(-10, 10) + + dist = torch.distributions.Independent( + torch.distributions.Normal(loc=mu, scale=torch.exp(log_sigma)), 1 + ) + distributions.append(dist) + + # Get the latents to condition on. + if z_q is not None: + z = z_q[level] + elif mean[level]: + z = dist.mean + (dist.base_dist.stddev * std_devs_from_mean[level]) + else: + z = dist.sample() + + used_latents.append(z) + + # Concat and upsample the latents with the previous features. + decoder_output_lo = torch.cat([z, decoder_features], axis=1) + decoder_output_hi = resize_up(decoder_output_lo, scale=2) + decoder_features = torch.cat( + [decoder_output_hi, encoder_outputs[::-1][level + 1]], axis=1 + ) + decoder_features = self.decoder_layers[level](decoder_features) + + return { + "decoder_features": decoder_features, + "encoder_features": encoder_outputs, + "distributions": distributions, + "used_latents": used_latents, + } + + +class _StitchingDecoder(torch.nn.Module): + """A module that completes the truncated U-Net decoder. 
+    Using the output of the HierarchicalCore, this module fills in the missing
+    decoder levels such that together the two form a symmetric U-Net.
+    """
+
+    def __init__(
+        self,
+        latent_dims,
+        channels_per_block,
+        num_classes,
+        down_channels_per_block=None,
+        activation_fn=torch.nn.ReLU,
+        convs_per_block=2,
+        blocks_per_level=1,
+        ndims=2,
+    ):
+        """Initializes a StitchingDecoder.
+
+        Args:
+            latent_dims (list): List of integers specifying the dimensions of the latents at each
+                scale. The length of the list indicates the number of U-Net decoder
+                scales that have latents.
+            channels_per_block (list): A list of integers specifying the number of output channels
+                for each encoder block.
+            num_classes (int): The number of segmentation classes.
+            down_channels_per_block (list, optional): A list of integers specifying the number of
+                intermediate channels for each encoder
+                block. If None, the intermediate channels
+                are chosen equal to channels_per_block.
+                Defaults to None.
+            activation_fn (torch.nn.Module, optional): A callable activation function. Defaults to
+                torch.nn.ReLU.
+            convs_per_block (int, optional): An integer specifying the number of convolutional
+                layers. Defaults to 2.
+            blocks_per_level (int, optional): An integer specifying the number of residual blocks
+                per level. Defaults to 1.
+            ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2.
+        """
+        super(_StitchingDecoder, self).__init__()
+        self._latent_dims = latent_dims
+        self._channels_per_block = channels_per_block
+        self._num_classes = num_classes
+        self._activation_fn = activation_fn
+        self._convs_per_block = convs_per_block
+        self._blocks_per_level = blocks_per_level
+        if down_channels_per_block is None:
+            down_channels_per_block = channels_per_block
+        self._down_channels_per_block = down_channels_per_block
+
+        num_latents = len(self._latent_dims)
+        self._start_level = num_latents + 1
+        self._num_levels = len(self._channels_per_block)
+
+        self.layers = torch.nn.ModuleList()
+        decoder_in_channels = None
+        for level in range(self._start_level, self._num_levels, 1):
+
+            decoder_in_channels = (
+                channels_per_block[::-1][level - 1] + channels_per_block[::-1][level]
+            )
+
+            layer = []
+            for _ in range(self._blocks_per_level):
+                layer.append(
+                    ResBlock(
+                        decoder_in_channels,
+                        channels_per_block[::-1][level],
+                        n_down_channels=self._down_channels_per_block[::-1][level],
+                        activation_fn=self._activation_fn,
+                        convs_per_block=self._convs_per_block,
+                        ndims=ndims,
+                    )
+                )
+                decoder_in_channels = channels_per_block[::-1][level]
+
+            self.layers.append(torch.nn.Sequential(*layer))
+            # self.layers.apply(init_weights)
+
+        if decoder_in_channels is None:
+            decoder_in_channels = channels_per_block[::-1][self._num_levels - 1]
+
+        self.final_layer = conv_nd(
+            ndims=ndims,
+            in_channels=decoder_in_channels,
+            out_channels=self._num_classes,
+            kernel_size=1,
+            padding=0,
+        )
+        # self.final_layer.apply(init_weights)
+
+    def forward(self, encoder_features, decoder_features):
+        """Forward pass through the stitching decoder
+
+        Args:
+            encoder_features (list): List of encoder feature tensors, one per scale
+            decoder_features (torch.Tensor): Tensor of decoder features
+
+        Returns:
+            torch.Tensor: The stitched output
+        """
+
+        for level in range(len(self.layers)):
+            enc_level = self._start_level + level
+            decoder_features = resize_up(decoder_features, scale=2)
+            decoder_features = torch.cat(
+                [decoder_features, encoder_features[::-1][enc_level]], axis=1
+            )
+            decoder_features = self.layers[level](decoder_features)
+
+        return self.final_layer(decoder_features)
+
+
+class HierarchicalProbabilisticUnet(torch.nn.Module):
+    """A hierarchical probabilistic UNet implementation: https://arxiv.org/abs/1905.13077"""
+
+    def __init__(
+        self,
+        input_channels=1,
+        num_classes=2,
+        filters_per_layer=None,
+        down_channels_per_block=None,
+        latent_dims=(1, 1, 1, 1),
+        convs_per_block=2,
+        blocks_per_level=1,
+        loss_type="elbo",
+        loss_params={"beta": 1},
+        ndims=2,
+    ):
+        """Initialize the Hierarchical Probabilistic UNet
+
+        Args:
+            input_channels (int, optional): The number of channels in the image (1 for
+                greyscale and 3 for RGB). Defaults to 1.
+            num_classes (int, optional): The number of classes to predict. Defaults to 2.
+            filters_per_layer (list, optional): A list specifying the number of filters
+                (output channels) for the blocks at each layer. Defaults to None, in
+                which case a built-in default is used.
+            down_channels_per_block (list, optional): A list specifying the number of
+                intermediate channels within the blocks at each layer. Defaults to None,
+                in which case half of filters_per_layer is used.
+            latent_dims (tuple, optional): The number of latent dimensions at each layer.
+                Defaults to (1, 1, 1, 1).
+            convs_per_block (int, optional): An integer specifying the number of convolutional
+                layers. Defaults to 2.
+            blocks_per_level (int, optional): An integer specifying the number of residual
+                blocks per level. Defaults to 1.
+            loss_type (str, optional): The type of loss to use, "elbo" or "geco".
+                Defaults to "elbo".
+            loss_params (dict, optional): Dictionary of arguments used by the loss function.
+                Defaults to {"beta": 1}.
+            ndims (int, optional): Specify whether to use 2 or 3 dimensions. Defaults to 2.
+        """
+        super(HierarchicalProbabilisticUnet, self).__init__()
+
+        base_channels = 24
+        default_filters_per_layer = (
+            base_channels,
+            2 * base_channels,
+            4 * base_channels,
+            8 * base_channels,
+            8 * base_channels,
+            8 * base_channels,
+            8 * base_channels,
+            8 * base_channels,
+        )
+        if filters_per_layer is None:
+            filters_per_layer = default_filters_per_layer
+        if down_channels_per_block is None:
+            down_channels_per_block = [int(i / 2) for i in filters_per_layer]
+
+        self.prior = _HierarchicalCore(
+            input_channels=input_channels,
+            latent_dims=latent_dims,
+            channels_per_block=filters_per_layer,
+            down_channels_per_block=down_channels_per_block,
+            convs_per_block=convs_per_block,
+            blocks_per_level=blocks_per_level,
+            ndims=ndims,
+        )
+
+        self.posterior = _HierarchicalCore(
+            input_channels=input_channels + num_classes,
+            latent_dims=latent_dims,
+            channels_per_block=filters_per_layer,
+            down_channels_per_block=down_channels_per_block,
+            convs_per_block=convs_per_block,
+            blocks_per_level=blocks_per_level,
+            ndims=ndims,
+        )
+
+        self.fcomb = _StitchingDecoder(
+            latent_dims=latent_dims,
+            channels_per_block=filters_per_layer,
+            num_classes=num_classes,
+            down_channels_per_block=down_channels_per_block,
+            convs_per_block=convs_per_block,
+            blocks_per_level=blocks_per_level,
+            ndims=ndims,
+        )
+        self.ndims = ndims
+
+        self._cache = None
+
+        self.loss_type = loss_type
+        self.loss_params = loss_params
+
+        if self.loss_type == "geco":
+            self._rec_moving_avg = None
+            self._contour_moving_avg = None
+            self.register_buffer("_lambda", torch.zeros(2, requires_grad=False))
+
+        self._q_sample = None
+        self._q_sample_mean = None
+        self._p_sample = None
+        self._p_sample_z_q = None
+        self._p_sample_z_q_mean = None
+
+    def forward(self, img, seg):
+        """Inserts all ops used during training into the graph exactly once. The first time this
+        method is called given the input pair (img, seg) all ops relevant for training are inserted
+        into the graph.
Calling this method more than once does not re-insert the modules into the + graph (memoization), thus preventing multiple forward passes of submodules for the same + inputs. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + """ + + input_tensor = torch.cat([img, seg], axis=1) + + if not self._cache is None and torch.equal(self._cache, input_tensor): + # No need to recompute + return + + self._q_sample = self.posterior(input_tensor, mean=False) + self._q_sample_mean = self.posterior(input_tensor, mean=True) + self._p_sample = self.prior(img, mean=False, z_q=None) + self._p_sample_z_q = self.prior(img, z_q=self._q_sample["used_latents"]) + self._p_sample_z_q_mean = self.prior(img, z_q=self._q_sample_mean["used_latents"]) + self._cache = input_tensor + + def sample(self, img, mean=False, std_devs_from_mean=0.0, z_q=None): + """Sample a segmentation from the prior, given an input image. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + mean (bool, optional): A boolean or a list of booleans. If a boolean, it specifies + whether or not to use the distributions' means in ALL latent + scales. If a list, each bool therein specifies whether or not to + use the scale's mean. If False, the latents of the scale are + sampled. Defaults to False. + std_devs_from_mean (float|list, optional): A float or list of floats describing how far + from the mean should be sampled. Only at + scales where mean is True. Defaults to 0. + z_q (list, optional): If not None, z_q provides external latents to be used instead of + sampling them. This is used to employ posterior latents in the + prior during training. Therefore, if z_q is not None, the value + of `mean` is ignored. If z_q is None, either the distributions + mean is used (in case `mean` for the respective scale is True) or + else a sample from the distribution is drawn. Defaults to None. + + Returns: + torch.Tensor: A segmentation tensor of shape (b, num_classes, h, w). + """ + + prior_out = self.prior(img, mean, std_devs_from_mean, z_q) + encoder_features = prior_out["encoder_features"] + decoder_features = prior_out["decoder_features"] + return self.fcomb(encoder_features, decoder_features) + + def reconstruct(self, img, seg, mean=False): + """Reconstruct a segmentation using the posterior. + + Args: + img ([torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + mean (bool, optional): A boolean, specifying whether to sample from the full hierarchy + of the posterior or use the posterior means at each scale of the + hierarchy. Defaults to False. + + Returns: + torch.Tensor: A segmentation tensor of shape (b,num_classes,h,w). + """ + + # self.forward(img, seg) + if mean: + prior_out = self._p_sample_z_q_mean + else: + prior_out = self._p_sample_z_q + encoder_features = prior_out["encoder_features"] + decoder_features = prior_out["decoder_features"] + return self.fcomb(encoder_features, decoder_features) + + def kl(self, img, seg): + """Kullback-Leibler divergence between the posterior and the prior. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + + Returns: + dict: A dictionary with keys indexing the hierarchy's levels and corresponding + values holding the KL-term for each level (per batch). 
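+
+                Illustrative (assumed) value for a model with four latent scales:
+                {0: tensor(2.31), 1: tensor(0.87), 2: tensor(0.25), 3: tensor(0.04)}.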
+ """ + self.forward(img, seg) + posterior_out = self._q_sample + prior_out = self._p_sample_z_q + + q_dists = posterior_out["distributions"] + p_dists = prior_out["distributions"] + + kl = {} + for level, (p, q) in enumerate(zip(p_dists, q_dists)): + kl_per_pixel = torch.distributions.kl.kl_divergence(p, q) + + if self.ndims == 2: + kl_per_instance = torch.sum(kl_per_pixel, [1, 2]) + else: + kl_per_instance = torch.sum(kl_per_pixel, [1, 2, 3]) + + kl_clamp = img.shape[2:].numel() * 10 + kl_per_instance = kl_per_instance.clamp(0, kl_clamp) + kl[level] = torch.mean(kl_per_instance) + + return kl + + def topk_mask(self, score, k): + """Returns a mask for the top-k elements in score.""" + + values, _ = torch.topk(score, 1, axis=1) + _, indices = torch.topk(values, k, axis=0) + return torch.scatter_add( + torch.zeros(score.shape[0]).to(score.device), + 0, + indices.reshape(-1), + torch.ones(score.shape[0]).to(score.device), + ) + + def prepare_mask( + self, + mask, + top_k_percentage, + deterministic, + num_classes, + device, + batch_size, + n_pixels_in_batch, + xe, + ): + if mask is None or mask.sum() == 0: + mask = torch.ones(n_pixels_in_batch) + else: + # assert ( + # mask.shape == segm.shape + # ), f"The loss mask shape differs from the target shape: {mask.shape} vs. {segm.shape}." + mask = torch.reshape(mask, (-1,)) + mask = mask.to(device) + + if top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + mask = mask.unsqueeze(1).repeat((1, num_classes)) + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) + + return mask + + def reconstruction_loss( + self, + img, + segm, + mask=None, + top_k_percentage=None, + deterministic=True, + ): + + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + + reconstruction = self.reconstruct(img, segm) + + ##### + num_classes = reconstruction.shape[1] + y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) + t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + n_pixels_in_batch = y_flat.shape[0] + batch_size = segm.shape[0] + + xe = criterion(input=y_flat, target=t_flat) + xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + + # If multiple masks supplied, compute a loss for each mask + if hasattr(mask, "__iter__"): + ce_sums = [] + ce_means = [] + masks = [] + for this_mask in mask: + this_mask = self.prepare_mask( + this_mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + ce_sum_per_instance = torch.sum(this_mask * xe, axis=1) + ce_sums.append(torch.mean(ce_sum_per_instance, axis=0)) + ce_means.append(torch.sum(this_mask * xe) / torch.sum(this_mask)) + masks.append(this_mask) + + return ce_sums, ce_means, masks + + mask = self.prepare_mask( + mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + ce_sum_per_instance = torch.sum(mask * xe, axis=1) + ce_sum = torch.mean(ce_sum_per_instance, axis=0) + ce_mean = torch.sum(mask * xe) / 
torch.sum(mask) + + return ce_sum, ce_mean, mask + + def loss(self, img, seg, mask=None): + """The full training objective, either ELBO or GECO. + + Args: + img (torch.Tensor): A tensor of shape (b, c, h, w). + seg (torch.Tensor): A tensor of shape (b, num_classes, h, w). + + Raises: + NotImplementedError: Raised if loss function supplied isn't implemented yet. + + Returns: + dict: A dictionary holding the loss (with key 'loss') + """ + kl_summaries = {} + kl_dict = self.kl(img, seg) + kl_sum = torch.sum(torch.stack([kl for _, kl in kl_dict.items()], axis=-1)) + for level, kl in kl_dict.items(): + kl_summaries[f"kl_{level}"] = kl + + top_k_percentage = None + if "top_k_percentage" in self.loss_params: + top_k_percentage = self.loss_params["top_k_percentage"] + + loss_mask = None + if "kappa" in self.loss_params: + reconstruction_threshold = self.loss_params["kappa"] + contour_threshold = None + if "kappa_contour" in self.loss_params and self.loss_params["kappa_contour"] is not None: + loss_mask = [None, mask] + contour_threshold = self.loss_params["kappa_contour"] + + # Here we use the posterior sample sampled above + _, rec_loss_mean, _ = self.reconstruction_loss( + img, + seg, + top_k_percentage=top_k_percentage, + mask=loss_mask, + ) + + # If using contour mask in loss, we get back those in a list. Unpack here. + if contour_threshold: + contour_loss = rec_loss_mean[1] + contour_loss_mean = rec_loss_mean[1] + reconstruction_loss = rec_loss_mean[0] + rec_loss_mean = rec_loss_mean[0] + else: + reconstruction_loss = rec_loss_mean + + if self.loss_type == "elbo": + + return { + "loss": reconstruction_loss + self.loss_params["beta"] * kl_sum, + "rec_loss": reconstruction_loss, + "kl_div": kl_sum, + } + elif self.loss_type == "geco": + + with torch.no_grad(): + + moving_avg_factor = 0.5 + + rl = rec_loss_mean.detach() + if self._rec_moving_avg is None: + self._rec_moving_avg = rl + else: + self._rec_moving_avg = self._rec_moving_avg * moving_avg_factor + rl * ( + 1 - moving_avg_factor + ) + + rc = self._rec_moving_avg - reconstruction_threshold + + cc = 0 + if contour_threshold: + cl = contour_loss_mean.detach() + if self._contour_moving_avg is None: + self._contour_moving_avg = rl + else: + self._contour_moving_avg = ( + self._contour_moving_avg * moving_avg_factor + + cl * (1 - moving_avg_factor) + ) + + cc = self._contour_moving_avg - contour_threshold + + lambda_lower = self.loss_params["clamp_rec"][0] + lambda_upper = self.loss_params["clamp_rec"][1] + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] + + self._lambda[0] = (torch.exp(rc) * self._lambda[0]).clamp(lambda_lower, lambda_upper) + if self._lambda[0].isnan(): self._lambda[0] = lambda_upper + if contour_threshold: + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] + + self._lambda[1] = (torch.exp(cc) * self._lambda[1]).clamp(lambda_lower_contour, lambda_upper_contour) + if self._lambda[1].isnan(): self._lambda[1] = lambda_upper_contour + + # pylint: disable=access-member-before-definition + loss = (self._lambda[0] * reconstruction_loss) + kl_sum + + result = { + "loss": loss, + "rec_loss": reconstruction_loss, + "kl_div": kl_sum, + "lambda_rec": self._lambda[0], + "moving_avg": self._rec_moving_avg, + "reconstruction_threshold": reconstruction_threshold, + "rec_constraint": rc, + } + + if contour_threshold is not None: + result["loss"] = result["loss"] + (self._lambda[1] * contour_loss) 
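+                # The contour term enters the objective with its own multiplier,
+                # mirroring the reconstruction constraint handled above.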
+ result["contour_loss"] = contour_loss + result["contour_threshold"] = contour_threshold + result["contour_constraint"] = cc + result["moving_avg_contour"] = self._contour_moving_avg + result["lambda_contour"] = self._lambda[1] + result = {**result, **kl_summaries} + + return result + + else: + raise NotImplementedError("Loss must be 'elbo' or 'geco'") diff --git a/platipy/imaging/cnn/lidc_dataset.py b/platipy/imaging/cnn/lidc_dataset.py new file mode 100644 index 00000000..7a314be0 --- /dev/null +++ b/platipy/imaging/cnn/lidc_dataset.py @@ -0,0 +1,225 @@ +# Copyright 2022 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import pickle +import os + +import SimpleITK as sitk +import numpy as np +import matplotlib.pyplot as plt + +import torch + +from imgaug import augmenters as iaa +from imgaug.augmentables.segmaps import SegmentationMapsOnImage + +from platipy.imaging import ImageVisualiser + +def prepare_transforms(): + + sometimes = lambda aug: iaa.Sometimes(0.5, aug) + + seq = iaa.Sequential( + [ + sometimes( + iaa.Affine( + scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, + translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, + rotate=(-15, 15), + shear=(-8, 8), + cval=0, + ) + ), + # execute 0 to 2 of the following (less important) augmenters per image + iaa.SomeOf( + (0, 2), + [ + iaa.OneOf( + [ + iaa.GaussianBlur((0, 1.5)), + iaa.AverageBlur(k=(3, 5)), + ] + ), + sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))), + ], + random_order=True, + ), + sometimes(iaa.CoarseDropout((0.03, 0.15), size_percent=(0.02, 0.1))), + ], + random_order=True, + ) + + return seq + + +class LIDCDataset(torch.utils.data.Dataset): + """PyTorch Dataset for processing LIDC data""" + + def __init__( + self, + working_dir, + case_ids=None, + pickle_path="lidc.pickle", + augment_on_fly=True, + ): + """Prepare a dataset from Nifti images/labels + + Args: + data (list): List of dict's where each item contains keys: "image" and "label". Values + are paths to the Nifti file. "label" may be a list where each item is a path to one + observer. + working_dir (str|path): Working directory where to write prepared files. 
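+
+        Example (hypothetical paths):
+            >>> ds = LIDCDataset("./working_lidc", pickle_path="lidc.pickle")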
+ """ + + self.transforms = None + if augment_on_fly: + self.transforms = prepare_transforms() + self.slices = [] + self.working_dir = Path(working_dir) + + self.img_dir = self.working_dir.joinpath("img") + self.label_dir = self.working_dir.joinpath("label") + self.contour_mask_dir = self.working_dir.joinpath("contour_mask") + self.snap_dir = self.working_dir.joinpath("snapshots") + + self.img_dir.mkdir(exist_ok=True, parents=True) + self.label_dir.mkdir(exist_ok=True, parents=True) + self.contour_mask_dir.mkdir(exist_ok=True, parents=True) + self.snap_dir.mkdir(exist_ok=True, parents=True) + + # If data doesn't already exist, unpickle data and place into directory + if len(list(self.img_dir.glob("*"))) == 0: + pickle_path = Path(pickle_path) + + max_bytes = 2**31 - 1 + data = {} + + print("Loading file", pickle_path) + bytes_in = bytearray(0) + input_size = os.path.getsize(pickle_path) + with open(pickle_path, 'rb') as f_in: + for _ in range(0, input_size, max_bytes): + bytes_in += f_in.read(max_bytes) + new_data = pickle.loads(bytes_in) + data.update(new_data) + + for k,i in data.items(): + + pat_id = k.split("_")[0] + slice_id = k.split("_")[1].replace("slice", "") + + i["pixel_spacing"] = [float(a) for a in i["pixel_spacing"]] + + img_file = self.img_dir.joinpath(f"{pat_id}_{slice_id}.npy") + np.save(img_file, i["image"]) + + intersection = None + union = None + vis = ImageVisualiser(sitk.GetImageFromArray(np.expand_dims(i["image"], axis=0)), axis="z", window=[0,1]) + for obs, mask in enumerate(i["masks"]): + + vis.add_contour(sitk.GetImageFromArray(np.expand_dims(mask, axis=0)), name=f"{obs}") + label_file = self.label_dir.joinpath(f"{pat_id}_{slice_id}_{obs}.npy") + np.save(label_file, mask) + + mask = mask.astype(int) + + if intersection is None: + intersection = np.copy(mask) + else: + intersection += mask + + if union is None: + union = np.copy(mask) + else: + union += mask + + intersection[intersection>1] = 1 + union[union cases[case]["slices"]: + cases[case]["slices"] = z.item() + if not observer in cases[case]["observers"]: + cases[case]["observers"].append(observer) + + metrics = {"JI": [], "DSC": [], "HD": [], "ASD": []} + for case in cases: + + img_arrs = [] + pred_arrs = [] + slices = [] + for z in range(cases[case]["slices"] + 1): + img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + pred_file = self.validation_directory.joinpath(f"pred_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + pred_arrs.append(np.load(pred_file)) + slices.append(z) + + if len(slices) < 5: + # Likely initial sanity check + continue + + img_arr = np.stack(img_arrs) + img = sitk.GetImageFromArray(img_arr) + img.SetSpacing(self.hparams.spacing) + + pred_arr = np.stack(pred_arrs) + pred = sitk.GetImageFromArray(pred_arr) + pred = sitk.Cast(pred, sitk.sitkUInt8) + pred = postprocess_mask(pred) + pred.CopyInformation(img) + sitk.WriteImage(pred, f"val_pred_{case}.nii.gz") + + color_dict = {} + obs_dict = {} + + for _, observer in enumerate(cases[case]["observers"]): + mask_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + + mask_arrs.append(np.load(mask_file)) + + mask_arr = np.stack(mask_arrs) + mask = sitk.GetImageFromArray(mask_arr) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) + obs_dict[f"manual_{observer}"] = mask + color_dict[f"manual_{observer}"] = [0.7, 0.2, 0.2] + + com = None + try: + com = get_com(mask) + except: + com = [int(i / 2) for i in 
img.GetSize()] + print(com) + + img_vis = ImageVisualiser(img, cut=com, figure_size_in=16) + # img_vis.set_limits_from_label(mask, expansion=[0, 0, 0]) + + contour_dict = {**obs_dict} + contour_dict["pred"] = pred + color_dict["pred"] = [0.2, 0.4, 0.8] + + img_vis.add_contour(contour_dict, color=color_dict) + fig = img_vis.show() + figure_path = f"valid_{case}.png" + fig.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + case_metrics = get_metrics(pred, mask) + for m in case_metrics: + metrics[m].append(case_metrics[m]) + + for m in metrics: + self.log( + m, + np.array(metrics[m]).mean(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, +# batch_size=self.hparams.batch_size, + ) diff --git a/platipy/imaging/cnn/metrics.py b/platipy/imaging/cnn/metrics.py new file mode 100644 index 00000000..cd4f3494 --- /dev/null +++ b/platipy/imaging/cnn/metrics.py @@ -0,0 +1,155 @@ +import collections +import math + +import SimpleITK as sitk +import numpy as np + +from platipy.imaging.label.comparison import compute_surface_dsc, compute_metric_dsc +from platipy.imaging.label.utils import get_union_mask, get_intersection_mask + + +def probabilistic_dice(gt_labels, sampled_labels, dsc_type="dsc", tau=3): + + gt_union = get_union_mask(gt_labels) + gt_intersection = get_intersection_mask(gt_labels) + + st_union = get_union_mask(sampled_labels) + st_intersection = get_intersection_mask(sampled_labels) + + if dsc_type == "dsc": + dsc_union = compute_metric_dsc(gt_union, st_union) + dsc_intersection = compute_metric_dsc(gt_intersection, st_intersection) + + if dsc_type == "sdsc": + dsc_union = compute_surface_dsc(gt_union, st_union, tau=tau) + dsc_intersection = compute_surface_dsc(gt_intersection, st_intersection, tau=tau) + + return (dsc_union + dsc_intersection) / 2 + + +def probabilistic_surface_dice(gt_labels, sampled_labels, sd_range=3, tau=0): + + if isinstance(gt_labels, dict): + gt_labels = [gt_labels[l] for l in gt_labels] + + if isinstance(sampled_labels, dict): + sampled_labels = [sampled_labels[l] for l in sampled_labels] + + binary_contour_filter = sitk.BinaryContourImageFilter() + binary_contour_filter.FullyConnectedOff() + summed = None + for mask in gt_labels: + + if summed is None: + summed = mask + + else: + summed += mask + + intersection = summed >= 1 + union = summed >= 5 + + mask_mean = summed >= 3 + intersection_minus_mean = intersection - mask_mean + mean_minus_union = mask_mean - union + + contour_i = binary_contour_filter.Execute(intersection) + contour_u = binary_contour_filter.Execute(union) + contour_mean = binary_contour_filter.Execute(mask_mean) + + dist_to_i = sitk.SignedMaurerDistanceMap( + contour_i, useImageSpacing=True, squaredDistance=False + ) + + dist_to_u = sitk.SignedMaurerDistanceMap( + contour_u, useImageSpacing=True, squaredDistance=False + ) + + dist_to_mean = sitk.SignedMaurerDistanceMap( + contour_mean, useImageSpacing=True, squaredDistance=False + ) + + mean = 0 + sd = 1 / sd_range + max_agg = np.pi * sd + + dist_sum = dist_to_mean + dist_to_i + dist_ratio_neg = dist_to_mean / dist_sum + + dist_ratio_arr = sitk.GetArrayFromImage(dist_ratio_neg) + + dist_ratio_arr = (np.pi * sd) * np.exp(-0.5 * ((dist_ratio_arr - mean) / sd) ** 2) + dist_ratio_arr = dist_ratio_arr / max_agg / 2 # Normalise + dist_ratio_arr[sitk.GetArrayFromImage(intersection_minus_mean) == 0] = 0 + dist_ratio_neg = sitk.GetImageFromArray(dist_ratio_arr) + 
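+    # Restore the spatial metadata lost in the array round-trip before combining
+    # the negative and positive ratio maps.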
dist_ratio_neg.CopyInformation(dist_sum) + + dist_sum = dist_to_mean + dist_to_u + dist_ratio_pos = dist_to_u / dist_sum + + dist_ratio_arr = sitk.GetArrayFromImage(dist_ratio_pos) + + dist_ratio_arr = (np.pi * sd) * np.exp(-0.5 * ((dist_ratio_arr - mean) / sd) ** 2) + dist_ratio_arr = (dist_ratio_arr / max_agg / 2) + 0.5 # Normalise + dist_ratio_arr[sitk.GetArrayFromImage(mean_minus_union) == 0] = 0 + dist_ratio_arr[sitk.GetArrayFromImage(union) == 1] = 1 + dist_ratio_pos = sitk.GetImageFromArray(dist_ratio_arr) + dist_ratio_pos.CopyInformation(dist_sum) + + dist_ratio = dist_ratio_neg + dist_ratio_pos + + sample_count = math.floor(len(sampled_labels) / 2) + + ranges = {} + range_masks = {} + start_mask = None + for pr in np.linspace(0.5, 1, sample_count + 1): + next_mask = dist_ratio >= pr + next_contour = binary_contour_filter.Execute(next_mask) + + if start_mask is None: + ranges[pr] = next_contour + else: + ranges[pr] = ((start_mask - next_mask) + start_contour + next_contour) > 0 + + range_masks[pr] = next_mask + + start_mask = next_mask + start_contour = binary_contour_filter.Execute(start_mask) + + start_mask = None + for pr in np.linspace(0.5, 0.000001, sample_count + 1): + next_mask = dist_ratio >= pr + next_contour = binary_contour_filter.Execute(next_mask) + + if start_mask is None: + ranges[pr] = next_contour + else: + ranges[pr] = ((next_mask - start_mask) + start_contour + next_contour) > 0 + + range_masks[pr] = next_mask + + start_mask = next_mask + start_contour = binary_contour_filter.Execute(start_mask) + + ranges = collections.OrderedDict(sorted(ranges.items())) + range_masks = collections.OrderedDict(sorted(range_masks.items())) + + result = 0 + for idx, r in enumerate(ranges): + auto_mask = sampled_labels[idx] + auto_contour = binary_contour_filter.Execute(auto_mask > 0) + + dist_to_range = sitk.SignedMaurerDistanceMap( + ranges[r], useImageSpacing=True, squaredDistance=False + ) + + auto_intersection = sitk.GetArrayFromImage(auto_contour * (dist_to_range <= tau)).sum() + + this_result = auto_intersection / sitk.GetArrayFromImage(auto_contour).sum() + if np.isnan(this_result): + this_result = 0 + + result += this_result + + return result / len(ranges) diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py new file mode 100644 index 00000000..219a24df --- /dev/null +++ b/platipy/imaging/cnn/prob_unet.py @@ -0,0 +1,619 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Parts of this work are derived from: +# https://github.com/stefanknegt/Probabilistic-Unet-Pytorch +# which is released under the Apache Licence 2.0 + +import torch +from torch.distributions import Normal, Independent, kl + +from platipy.imaging.cnn.unet import UNet, Conv, init_weights, conv_nd + + +class Encoder(torch.nn.Module): + """Encoder part of the probabilistic UNet""" + + def __init__( + self, input_channels, filters_per_layer=[64 * (2**x) for x in range(5)], ndims=2, dropout_probability=None + ): + super(Encoder, self).__init__() + + layers = [] + for idx, layer_filters in enumerate(filters_per_layer): + + input_filters = input_channels if idx == 0 else output_filters + output_filters = layer_filters + + down_sample = 0 if idx == 0 else -2 + + layers.append( + Conv( + input_filters, + output_filters, + up_down_sample=down_sample, + ndims=ndims, + dropout_probability=dropout_probability, + ) + ) + + self.layers = torch.nn.Sequential(*layers) + + self.layers.apply(init_weights) + + def forward(self, x): + + return self.layers(x) + + +class AxisAlignedConvGaussian(torch.nn.Module): + def __init__( + self, + input_channels, + filters_per_layer=[64 * (2**x) for x in range(5)], + latent_dim=2, + ndims=2, + dropout_probability=0.0 + ): + + super(AxisAlignedConvGaussian, self).__init__() + + self.latent_dim = latent_dim + + self.encoder = Encoder(input_channels, filters_per_layer, ndims=ndims, dropout_probability=dropout_probability) + + self.final = conv_nd( + in_channels=filters_per_layer[-1], + out_channels=2 * self.latent_dim, + kernel_size=1, + stride=1, + ndims=ndims, + ) + + self.ndims = ndims + + self.final.apply(init_weights) + + def forward(self, img, seg=None): + """Forward pass through the network + + Args: + img (torch.Tensor): The image to be passed through. + seg (torch.Tensor, optional): The segmentation mask to use in the case of the prior + network. Defaults to None. + + Returns: + torch.distributions.distribution.Distribution: The distribution output + """ + + x = img + if seg is not None: + # seg = torch.unsqueeze(seg, dim=1) + x = torch.cat((img, seg), dim=1) + + encoding = self.encoder(x) + + # We only want the mean of the resulting hxw image + encoding = torch.mean(encoding, dim=2, keepdim=True) + encoding = torch.mean(encoding, dim=3, keepdim=True) + if self.ndims == 3: + encoding = torch.mean(encoding, dim=4, keepdim=True) + + # Convert encoding to 2 x latent dim and split up for mu and log_sigma + mu_log_sigma = self.final(encoding) + + # We squeeze the second dimension twice, since otherwise it won't work when batch size is + # equal to 1 + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + if self.ndims == 3: + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + + mu = mu_log_sigma[:, : self.latent_dim].clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim :].clamp(-10, 10) + + # This is a multivariate normal with diagonal covariance matrix sigma + # https://github.com/pytorch/pytorch/pull/11178 + dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) + + return dist + + +class Fcomb(torch.nn.Module): + """ + A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken + from the latent space, and output of the UNet (the feature map) by concatenating them along + their channel axis. 
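+
+    For example (assumed 2D shapes): a latent sample z of shape (b, latent_dim) is
+    broadcast to (b, latent_dim, h, w) and concatenated with the (b, c, h, w) UNet
+    feature map before the 1x1 convolutions.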
+ """ + + def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=2): + super(Fcomb, self).__init__() + + layers = [] + + # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the + # last layer + layers.append( + conv_nd( + in_channels=filters_per_layer[0] + latent_dim, + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + for _ in range(no_convs_fcomb - 2): + layers.append( + conv_nd( + in_channels=filters_per_layer[0], + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + self.layers = torch.nn.Sequential(*layers) + + self.last_layer = conv_nd( + in_channels=filters_per_layer[0], out_channels=num_classes, kernel_size=1, ndims=ndims + ) + + self.layers.apply(init_weights) + self.last_layer.apply(init_weights) + + self.ndims = ndims + + def forward(self, feature_map, z): + + z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2]) + z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3]) + if self.ndims == 3: + z = torch.unsqueeze(z, 4).expand(-1, -1, -1, -1, feature_map.shape[4]) + + # Concatenate the feature map (output of the UNet) and the sample taken from the latent + # space + feature_map = torch.cat((feature_map, z), dim=1) + output = self.layers(feature_map) + return self.last_layer(output) + + +class ProbabilisticUnet(torch.nn.Module): + """ + A probabilistic UNet implementation + (https://papers.nips.cc/paper/2018/file/473447ac58e1cd7e96172575f48dca3b-Paper.pdf) + + input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) + num_classes: the number of classes to predict + num_filters: is a list consisint of the amount of filters layer + latent_dim: dimension of the latent space + no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior + """ + + def __init__( + self, + input_channels=1, + num_classes=2, + filters_per_layer=[64 * (2**x) for x in range(5)], + latent_dim=6, + no_convs_fcomb=4, + loss_type="elbo", + loss_params={"beta": 1}, + ndims=2, + dropout_probability=0.0, + use_structure_context=False + ): + super(ProbabilisticUnet, self).__init__() + + self.num_classes = num_classes + self.no_convs_per_block = 3 + self.no_convs_fcomb = no_convs_fcomb + self.initializers = {"w": "he_normal", "b": "normal"} + self.z_prior_sample = 0 + self.latent_dim = latent_dim + self.use_structure_context = use_structure_context + + unet_input_channels = input_channels + if use_structure_context: + unet_input_channels = unet_input_channels + num_classes + + self.unet = UNet( + unet_input_channels, + num_classes, + filters_per_layer, + final_layer=False, + dropout_probability=dropout_probability, + ndims=ndims, + ) + + self.prior = AxisAlignedConvGaussian( + unet_input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + ) + + post_channels = input_channels + num_classes + if use_structure_context: + post_channels = input_channels + (num_classes * 2) + + self.posterior = AxisAlignedConvGaussian( + post_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + ) + self.fcomb = Fcomb(filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=ndims) + + self.loss_type = loss_type + self.loss_params = loss_params + + self.posterior_latent_space = None + self.prior_latent_space = None + self.unet_features = None + + if self.loss_type == "geco": + 
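+            # GECO bookkeeping: moving averages of the constraint losses and a pair of
+            # Lagrange-style multipliers (_lambda[0] for reconstruction, _lambda[1] for contour).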
self._rec_moving_avg = None
+            self._contour_moving_avg = None
+            self.register_buffer("_lambda", torch.ones(2, requires_grad=False))
+
+        self.register_buffer("_pos_weight", torch.ones(num_classes, requires_grad=False))
+
+    def forward(self, img, seg=None, cseg=None, training=False):
+        """
+        Construct the prior latent space for the patch and run the patch through the UNet;
+        when training is True, also construct the posterior latent space.
+        """
+
+        if self.use_structure_context:
+            if cseg is None:
+                raise ValueError("Structure context is enabled, but no context segmentation mask provided")
+
+            img = torch.cat((img, cseg), dim=1)
+
+        if training:
+            self.posterior_latent_space = self.posterior.forward(img, seg=seg)
+
+        self.prior_latent_space = self.prior.forward(img)
+
+        self.unet_features = self.unet.forward(img)
+
+    def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None):
+        """
+        Sample a segmentation by reconstructing from a prior sample
+        and combining this with UNet features
+        """
+
+        latent_space = self.prior_latent_space
+
+        if testing:
+            if use_mean:
+                z_prior = latent_space.base_dist.loc
+            elif sample_x_stddev_from_mean is not None:
+                if isinstance(sample_x_stddev_from_mean, list):
+                    sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean)
+                sample_x_stddev_from_mean = sample_x_stddev_from_mean.to(
+                    latent_space.base_dist.stddev.device
+                )
+                z_prior = self.prior_latent_space.base_dist.loc + (
+                    latent_space.base_dist.scale * sample_x_stddev_from_mean
+                )
+            else:
+                z_prior = latent_space.sample()
+            self.z_prior_sample = z_prior
+        else:
+            z_prior = latent_space.rsample()
+            self.z_prior_sample = z_prior
+
+        return self.fcomb.forward(self.unet_features, z_prior)
+
+    def reconstruct(self, use_posterior_mean=False, z_posterior=None, sample_x_stddev_from_mean=None):
+        """
+        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet
+        feature map
+
+        use_posterior_mean: use posterior_mean instead of sampling z_q
+        """
+        if use_posterior_mean:
+            z_posterior = self.posterior_latent_space.mean
+        elif sample_x_stddev_from_mean is not None:
+            if isinstance(sample_x_stddev_from_mean, list):
+                sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean)
+            sample_x_stddev_from_mean = sample_x_stddev_from_mean.to(
+                self.posterior_latent_space.base_dist.stddev.device
+            )
+            z_posterior = self.posterior_latent_space.base_dist.loc + (
+                self.posterior_latent_space.base_dist.scale * sample_x_stddev_from_mean
+            )
+        else:
+            if z_posterior is None:
+                z_posterior = self.posterior_latent_space.rsample()
+        return self.fcomb.forward(self.unet_features, z_posterior)
+
+    def kl_divergence(self):
+        """
+        Calculate the KL divergence between the posterior and prior KL(Q||P)
+        """
+
+        # if self.prior_latent_space is None:
+        #     device = self.posterior_latent_space.base_dist.stddev.device
+        #     dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim).to(device)), 1)
+        #     kl_div = kl.kl_divergence(self.posterior_latent_space, dist)
+        # else:
+        kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
+
+        return kl_div
+
+    def topk_mask(self, score, k):
+        """Returns a mask for the top-k elements in score."""
+
+        values, _ = torch.topk(score, 1, axis=1)
+        _, indices = torch.topk(values, k, axis=0)
+        return torch.scatter_add(
+            torch.zeros(score.shape[0]).to(score.device),
+            0,
+            indices.reshape(-1),
torch.ones(score.shape[0]).to(score.device), + ) + + def prepare_mask( + self, + mask, + top_k_percentage, + deterministic, + num_classes, + device, + batch_size, + n_pixels_in_batch, + xe, + ): + if mask is None or mask.sum() == 0: + mask = torch.ones(n_pixels_in_batch) + mask = mask.to(device) + + if top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + else: + mask = torch.reshape(mask, (-1,)) + + mask = mask.unsqueeze(1).repeat((1, num_classes)) + + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) + + return mask + + def reconstruction_loss( + self, + segm, + z_posterior=None, + mask=None, + top_k_percentage=None, + deterministic=True, + ): + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + # criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + + if z_posterior is None: + z_posterior = self.posterior_latent_space.rsample() + + reconstruction = self.reconstruct(use_posterior_mean=False, z_posterior=z_posterior) + + #criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + + # loss = criterion(input=reconstruction, target=segm) + # return loss, None, None + ##### + num_classes = reconstruction.shape[1] + y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) + t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + n_pixels_in_batch = y_flat.shape[0] + batch_size = segm.shape[0] + + # pos_class_count = t_flat.sum(axis=0) / batch_size + # neg_class_count = torch.logical_not(t_flat).sum(axis=0) / batch_size + # self._pos_weight = ( + # self._pos_weight * 0.5 + (neg_class_count / pos_class_count).clamp(0, 10000) * 0.5 + # ) + + # criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) + xe = criterion(input=y_flat, target=t_flat) + xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + + # If multiple masks supplied, compute a loss for each mask + if hasattr(mask, "__iter__"): + ce_sums = [] + ce_means = [] + masks = [] + for this_mask in mask: + this_mask = self.prepare_mask( + this_mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + ce_sum_per_instance = torch.sum(this_mask * xe, axis=1) + ce_sums.append(torch.mean(ce_sum_per_instance, axis=0)) + ce_means.append(torch.sum(this_mask * xe) / torch.sum(this_mask)) + masks.append(this_mask) + + return ce_sums, ce_means, masks + + mask = self.prepare_mask( + mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + ce_sum_per_instance = torch.sum(mask * xe, axis=1) + ce_sum = torch.mean(ce_sum_per_instance, axis=0) + ce_mean = torch.sum(mask * xe) / torch.sum(mask) + + return ce_sum, ce_mean, mask + + def loss(self, segm, mask=None, beta=None): + """ + Calculate the evidence lower bound of the log-likelihood of P(Y|X) + """ + + z_posterior = self.posterior_latent_space.rsample() + + kl_div = torch.mean(self.kl_divergence()) + # kl_div = 
torch.clamp(kl_div, 0.0, 100.0) + + top_k_percentage = None + if "top_k_percentage" in self.loss_params: + top_k_percentage = self.loss_params["top_k_percentage"] + + loss_mask = None + contour_threshold = None + if self.loss_type == "geco": + reconstruction_threshold = self.loss_params["kappa"] + if ( + "kappa_contour" in self.loss_params + and self.loss_params["kappa_contour"] is not None + ): + loss_mask = [None, mask] + contour_threshold = self.loss_params["kappa_contour"] + + # Here we use the posterior sample sampled above + rl_sum, rec_loss_mean, _ = self.reconstruction_loss( + segm, + z_posterior=z_posterior, + top_k_percentage=top_k_percentage, + mask=loss_mask, + ) + + # If using contour mask in loss, we get back those in a list. Unpack here. + if contour_threshold: + contour_loss = rl_sum[1] + contour_loss_mean = rec_loss_mean[1] + reconstruction_loss = rl_sum[0] + rec_loss_mean = rec_loss_mean[0] + else: + reconstruction_loss = rl_sum + + if self.loss_type == "elbo": + if beta is None: + beta = self.loss_params["beta"] + + return { + "loss": reconstruction_loss + beta * kl_div, + "rec_loss": reconstruction_loss, + "kl_div": kl_div, + "beta": beta, + } + elif self.loss_type == "geco": + + rec_geco_step_size = self.loss_params["rec_geco_step_size"] + + with torch.no_grad(): + + moving_avg_factor = 0.5 + + rl = rec_loss_mean.detach() + if self._rec_moving_avg is None: + self._rec_moving_avg = rl + else: + self._rec_moving_avg = self._rec_moving_avg * moving_avg_factor + rl * ( + 1 - moving_avg_factor + ) + + rc = self._rec_moving_avg - reconstruction_threshold + + cc = 0 + if contour_threshold: + cl = contour_loss_mean.detach() + if self._contour_moving_avg is None: + self._contour_moving_avg = rl + else: + self._contour_moving_avg = ( + self._contour_moving_avg * moving_avg_factor + + cl * (1 - moving_avg_factor) + ) + + cc = self._contour_moving_avg - contour_threshold + + lambda_lower = self.loss_params["clamp_rec"][0] + lambda_upper = self.loss_params["clamp_rec"][1] + + self._lambda[0] = (torch.exp(rc * rec_geco_step_size) * self._lambda[0]).clamp( + lambda_lower, lambda_upper + ) + # self._lambda[0] = (rc * self._lambda[0]).clamp(lambda_lower, lambda_upper) + if self._lambda[0].isnan(): + self._lambda[0] = lambda_upper + if contour_threshold: + lambda_lower_contour = self.loss_params["clamp_contour"][0] + lambda_upper_contour = self.loss_params["clamp_contour"][1] + + self._lambda[1] = (torch.exp(cc * rec_geco_step_size) * self._lambda[1]).clamp( + lambda_lower_contour, lambda_upper_contour + ) + # self._lambda[1] = (cc * self._lambda[1]).clamp( + # lambda_lower_contour, lambda_upper_contour + # ) + if self._lambda[1].isnan(): + self._lambda[1] = lambda_upper_contour + + # pylint: disable=access-member-before-definition + loss = (self._lambda[0] * reconstruction_loss) + kl_div + + result = { + "loss": loss, + "rec_loss": reconstruction_loss, + "kl_div": kl_div, + "lambda_rec": self._lambda[0], + "moving_avg": self._rec_moving_avg, + "reconstruction_threshold": reconstruction_threshold, + "rec_constraint": rc, + } + + if contour_threshold is not None: + result["loss"] = result["loss"] + (self._lambda[1] * contour_loss) + result["contour_loss"] = contour_loss + result["contour_threshold"] = contour_threshold + result["contour_constraint"] = cc + result["moving_avg_contour"] = self._contour_moving_avg + result["lambda_contour"] = self._lambda[1] + + return result + + else: + raise NotImplementedError("Loss must be 'elbo' or 'geco'") diff --git 
a/platipy/imaging/cnn/prob_unet_debug.py b/platipy/imaging/cnn/prob_unet_debug.py new file mode 100644 index 00000000..09dda08b --- /dev/null +++ b/platipy/imaging/cnn/prob_unet_debug.py @@ -0,0 +1,647 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Parts of this work are derived from: +# https://github.com/stefanknegt/Probabilistic-Unet-Pytorch +# which is released under the Apache Licence 2.0 + +import torch +import numpy as np +from torch.distributions import Normal, Independent, kl + +from platipy.imaging.cnn.unet import UNet, Conv, init_weights, conv_nd + + +class Encoder(torch.nn.Module): + """Encoder part of the probabilistic UNet""" + + def __init__( + self, input_channels, filters_per_layer=[64 * (2**x) for x in range(5)], ndims=2, dropout_probability=None + ): + super(Encoder, self).__init__() + + layers = [] + for idx, layer_filters in enumerate(filters_per_layer): + + input_filters = input_channels if idx == 0 else output_filters + output_filters = layer_filters + + down_sample = 0 if idx == 0 else -2 + + layers.append( + Conv( + input_filters, + output_filters, + up_down_sample=down_sample, + ndims=ndims, + dropout_probability=dropout_probability, + ) + ) + + self.layers = torch.nn.Sequential(*layers) + + self.layers.apply(init_weights) + + def forward(self, x): + + return self.layers(x) + + +class AxisAlignedConvGaussian(torch.nn.Module): + def __init__( + self, + input_channels, + filters_per_layer=[64 * (2**x) for x in range(5)], + latent_dim=2, + ndims=2, + dropout_probability=0.0 + ): + + super(AxisAlignedConvGaussian, self).__init__() + + self.latent_dim = latent_dim + + self.encoder = Encoder(input_channels, filters_per_layer, ndims=ndims, dropout_probability=dropout_probability) + + self.final = conv_nd( + in_channels=filters_per_layer[-1], + out_channels=2 * self.latent_dim, + kernel_size=1, + stride=1, + ndims=ndims, + ) + + self.ndims = ndims + + self.final.apply(init_weights) + + def forward(self, img, seg=None): + """Forward pass through the network + + Args: + img (torch.Tensor): The image to be passed through. + seg (torch.Tensor, optional): The segmentation mask to use in the case of the prior + network. Defaults to None. 
+ + Returns: + torch.distributions.distribution.Distribution: The distribution output + """ + + x = img + if seg is not None: + # seg = torch.unsqueeze(seg, dim=1) + x = torch.cat((img, seg), dim=1) + + encoding = self.encoder(x) + + # We only want the mean of the resulting hxw image + encoding = torch.mean(encoding, dim=2, keepdim=True) + encoding = torch.mean(encoding, dim=3, keepdim=True) + if self.ndims == 3: + encoding = torch.mean(encoding, dim=4, keepdim=True) + + # Convert encoding to 2 x latent dim and split up for mu and log_sigma + mu_log_sigma = self.final(encoding) + + # We squeeze the second dimension twice, since otherwise it won't work when batch size is + # equal to 1 + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + if self.ndims == 3: + mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2) + + mu = mu_log_sigma[:, : self.latent_dim].clamp(-1000, 1000) + log_sigma = mu_log_sigma[:, self.latent_dim :].clamp(-10, 10) + + # This is a multivariate normal with diagonal covariance matrix sigma + # https://github.com/pytorch/pytorch/pull/11178 + dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1) + + return dist + + +class Fcomb(torch.nn.Module): + """ + A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken + from the latent space, and output of the UNet (the feature map) by concatenating them along + their channel axis. + """ + + def __init__(self, filters_per_layer, latent_dim, num_classes, no_convs_fcomb, ndims=2): + super(Fcomb, self).__init__() + + layers = [] + + # Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the + # last layer + layers.append( + conv_nd( + in_channels=filters_per_layer[0] + latent_dim, + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + for _ in range(no_convs_fcomb - 2): + layers.append( + conv_nd( + in_channels=filters_per_layer[0], + out_channels=filters_per_layer[0], + kernel_size=1, + ndims=ndims, + ) + ) + layers.append(torch.nn.ReLU(inplace=True)) + + self.layers = torch.nn.Sequential(*layers) + + self.last_layer = conv_nd( + in_channels=filters_per_layer[0], out_channels=num_classes, kernel_size=1, ndims=ndims + ) + + self.layers.apply(init_weights) + self.last_layer.apply(init_weights) + + self.ndims = ndims + + def forward(self, feature_map, z): + + #z = torch.unsqueeze(z, 2).expand(-1, -1, feature_map.shape[2]) + #z = torch.unsqueeze(z, 3).expand(-1, -1, -1, feature_map.shape[3]) + #if self.ndims == 3: + # z = torch.unsqueeze(z, 4).expand(-1, -1, -1, -1, feature_map.shape[4]) + + # Concatenate the feature map (output of the UNet) and the sample taken from the latent + # space + # feature_map = torch.cat((feature_map, z), dim=1) + output = self.layers(feature_map) + return self.last_layer(output) + + +class ProbabilisticUnet(torch.nn.Module): + """ + A probabilistic UNet implementation + (https://papers.nips.cc/paper/2018/file/473447ac58e1cd7e96172575f48dca3b-Paper.pdf) + + input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) + num_classes: the number of classes to predict + num_filters: is a list consisint of the amount of filters layer + latent_dim: dimension of the latent space + no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior + """ + + def __init__( + self, + input_channels=1, + num_classes=2, + filters_per_layer=[64 * (2**x) for x in range(5)], + 
latent_dim=6, + no_convs_fcomb=4, + loss_type="elbo", + loss_params={"beta": 1}, + ndims=2, + dropout_probability=0.0, + use_structure_context=False + ): + super(ProbabilisticUnet, self).__init__() + + self.num_classes = num_classes + self.no_convs_per_block = 3 + self.no_convs_fcomb = no_convs_fcomb + self.initializers = {"w": "he_normal", "b": "normal"} + self.z_prior_sample = 0 + self.latent_dim = latent_dim + self.use_structure_context = use_structure_context + + unet_input_channels = input_channels + if use_structure_context: + unet_input_channels = unet_input_channels + num_classes + + self.unet = UNet( + unet_input_channels, + num_classes, + filters_per_layer, + final_layer=False, + dropout_probability=dropout_probability, + ndims=ndims, + ) + self.prior = None + if not use_structure_context: + self.prior = AxisAlignedConvGaussian( + input_channels, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + ) + #self.posterior = AxisAlignedConvGaussian( + # input_channels + num_classes, filters_per_layer, latent_dim, ndims=ndims, dropout_probability=0.0 + #) + self.fcomb = Fcomb(filters_per_layer, 0, num_classes, no_convs_fcomb, ndims=ndims) + + self.loss_type = loss_type + self.loss_params = loss_params + + self.posterior_latent_space = None + self.prior_latent_space = None + self.unet_features = None + + if self.loss_type == "geco": + self._rec_moving_avg = None + self._contour_moving_avg = None + self.register_buffer("_lambda", torch.ones(2, requires_grad=False)) + + self.register_buffer("_pos_weight", torch.ones(num_classes, requires_grad=False)) + + def forward(self, img, seg=None, training=False): + """ + Construct prior latent space for patch and run patch through UNet, + in case training is True also construct posterior latent space + """ + #if training or self.prior is None: + # self.posterior_latent_space = self.posterior.forward(img, seg=seg) + + self.prior_latent_space = None + if self.prior is not None: + self.prior_latent_space = self.prior.forward(img) + + if self.use_structure_context: + if seg is None: + raise ValueError("Structure context is enabled, but no segmentation mask provided") + import numpy as np + print(f"imgtype: {img.dtype}") + print(f"imgshape: {img.shape}") + print(f"segtype: {seg.dtype}") + np.save("imgg.npy", img.cpu().numpy()) + img = torch.cat((img, seg), dim=1) + np.save("imgg2.npy", img.cpu().numpy()) + np.save("segg.npy", seg.cpu().numpy()) + + self.unet_features = self.unet.forward(img) + + def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): + """ + Sample a segmentation by reconstructing from a prior sample + and combining this with UNet features + """ + + latent_space = self.prior_latent_space if self.prior is not None else self.posterior_latent_space + z_prior = None + + + testing = False + if testing: + if use_mean: + z_prior = latent_space.base_dist.loc + elif not sample_x_stddev_from_mean is None: + if isinstance(sample_x_stddev_from_mean, list): + sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to( + latent_space.base_dist.stddev.device + ) + z_prior = self.prior_latent_space.base_dist.loc + ( + latent_space.base_dist.scale * sample_x_stddev_from_mean + ) + else: + z_prior = latent_space.sample() + self.z_prior_sample = z_prior + else: + pass + z_prior = None + #z_prior = latent_space.rsample() + #self.z_prior_sample = z_prior + + return self.fcomb.forward(self.unet_features, z_prior) + + def reconstruct(self, 
use_posterior_mean=False, z_posterior=None, sample_x_stddev_from_mean=None): + """ + Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet + feature map + + use_posterior_mean: use posterior_mean instead of sampling z_q + """ + use_posterior_mean = False + sample_x_stddev_from_mean = None + if use_posterior_mean: + z_posterior = self.posterior_latent_space.mean + elif sample_x_stddev_from_mean is not None: + if isinstance(sample_x_stddev_from_mean, list): + sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to( + self.posterior_latent_space.base_dist.stddev.device + ) + z_posterior = self.posterior_latent_space.base_dist.loc + ( + self.posterior_latent_space.base_dist.scale * sample_x_stddev_from_mean + ) + else: + pass + z_posterior = None +# if z_posterior is None: +# z_posterior = self.posterior_latent_space.rsample() + return self.fcomb.forward(self.unet_features, z_posterior) + + def kl_divergence(self): + """ + Calculate the KL divergence between the posterior and prior KL(Q||P) + """ + + if self.prior_latent_space is None: + + device = self.posterior_latent_space.base_dist.stddev.device + dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim).to(device)), 1) + kl_div = kl.kl_divergence(self.posterior_latent_space, dist) + else: + kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space) + + return kl_div + + def topk_mask(self, score, k): + """Returns a mask for the top-k elements in score.""" + + values, _ = torch.topk(score, 1, axis=1) + _, indices = torch.topk(values, k, axis=0) + return torch.scatter_add( + torch.zeros(score.shape[0]).to(score.device), + 0, + indices.reshape(-1), + torch.ones(score.shape[0]).to(score.device), + ) + + def prepare_mask( + self, + mask, + top_k_percentage, + deterministic, + num_classes, + device, + batch_size, + n_pixels_in_batch, + xe, + ): + if mask is None or mask.sum() == 0: + mask = torch.ones(n_pixels_in_batch) + mask = mask.to(device) + + if top_k_percentage is not None: + + assert 0.0 < top_k_percentage <= 1.0 + k_pixels = int(n_pixels_in_batch * top_k_percentage) + + with torch.no_grad(): + norm_xe = xe / torch.sum(xe) + if deterministic: + score = torch.log(norm_xe) + else: + # TODO Gumbel trick + raise NotImplementedError("Still need to implement Gumbel trick") + + score = score + torch.log(mask.unsqueeze(1).repeat((1, num_classes))) + + top_k_mask = self.topk_mask(score, k_pixels) + top_k_mask = top_k_mask.to(device) + mask = mask * top_k_mask + + else: + mask = torch.reshape(mask, (-1,)) + + mask = mask.unsqueeze(1).repeat((1, num_classes)) + + mask = ( + mask.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + ) + + return mask + + def reconstruction_loss( + self, + segm, + z_posterior=None, + mask=None, + top_k_percentage=None, + deterministic=True, + ): + criterion = torch.nn.BCEWithLogitsLoss(reduction="none") + + # criterion = torch.nn.BCEWithLogitsLoss(size_average=False, reduce=False, reduction=None) + + # if z_posterior is None: + # z_posterior = self.posterior_latent_space.rsample() + z_posterior = 1 + + reconstruction = self.reconstruct(use_posterior_mean=False, z_posterior=z_posterior) + + criterion = torch.nn.BCEWithLogitsLoss(reduction="mean") + + # Take the max of all structure to combine into one big structure to localise + y = segm + pred = reconstruction + np.save("predd.npy", pred.cpu().detach().numpy()) + 
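+        # Debug dump: persist the raw prediction so it can be inspected offline alongside
+        # the target saved below as "yyy.npy".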
#y = y.max(axis=1).values + # y = torch.unsqueeze(y, dim=1) + + # Add a background for the localise UNet + # not_y = y.logical_not() + # y = torch.cat((not_y, y), dim=1).float() + np.save("yyy.npy", y.cpu().detach().numpy()) + + loss = criterion(input=pred, target=y) + return loss, None, None + + ##### + num_classes = reconstruction.shape[1] + y_flat = torch.transpose(reconstruction, 1, -1).reshape((-1, num_classes)) + t_flat = torch.transpose(segm, 1, -1).reshape((-1, num_classes)) + n_pixels_in_batch = y_flat.shape[0] + batch_size = segm.shape[0] + + # pos_class_count = t_flat.sum(axis=0) / batch_size + # neg_class_count = torch.logical_not(t_flat).sum(axis=0) / batch_size + # self._pos_weight = ( + # self._pos_weight * 0.5 + (neg_class_count / pos_class_count).clamp(0, 10000) * 0.5 + # ) + + # criterion = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=self._pos_weight) + xe = criterion(input=y_flat, target=t_flat) + xe = xe.reshape((batch_size, -1, num_classes)).transpose(-1, 1).reshape((batch_size, -1)) + + # If multiple masks supplied, compute a loss for each mask + if hasattr(mask, "__iter__"): + ce_sums = [] + ce_means = [] + masks = [] + for this_mask in mask: + this_mask = self.prepare_mask( + this_mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + ce_sum_per_instance = torch.sum(this_mask * xe, axis=1) + ce_sums.append(torch.mean(ce_sum_per_instance, axis=0)) + ce_means.append(torch.sum(this_mask * xe) / torch.sum(this_mask)) + masks.append(this_mask) + + return ce_sums, ce_means, masks + + mask = self.prepare_mask( + mask, + top_k_percentage, + deterministic, + num_classes, + y_flat.device, + batch_size, + n_pixels_in_batch, + xe, + ) + + ce_sum_per_instance = torch.sum(mask * xe, axis=1) + ce_sum = torch.mean(ce_sum_per_instance, axis=0) + ce_mean = torch.sum(mask * xe) / torch.sum(mask) + + return ce_sum, ce_mean, mask + + def loss(self, segm, mask=None, beta=None): + """ + Calculate the evidence lower bound of the log-likelihood of P(Y|X) + """ + + # z_posterior = self.posterior_latent_space.rsample() + z_posterior = False + # kl_div = torch.mean(self.kl_divergence()) + # kl_div = torch.clamp(kl_div, 0.0, 100.0) + print(f"##### {segm[0][1].max()}") + + top_k_percentage = None + if "top_k_percentage" in self.loss_params: + top_k_percentage = self.loss_params["top_k_percentage"] + + loss_mask = None + contour_threshold = None + if self.loss_type == "geco": + reconstruction_threshold = self.loss_params["kappa"] + if ( + "kappa_contour" in self.loss_params + and self.loss_params["kappa_contour"] is not None + ): + loss_mask = [None, mask] + contour_threshold = self.loss_params["kappa_contour"] + + # Here we use the posterior sample sampled above + rl_sum, rec_loss_mean, _ = self.reconstruction_loss( + segm, + z_posterior=z_posterior, + top_k_percentage=top_k_percentage, + mask=loss_mask, + ) + + # If using contour mask in loss, we get back those in a list. Unpack here. 
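+        # (rl_sum and rec_loss_mean are then [reconstruction, contour] pairs.)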
+    def loss(self, segm, mask=None, beta=None):
+        """
+        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
+        """
+
+        z_posterior = self.posterior_latent_space.rsample()
+        kl_div = torch.mean(self.kl_divergence())
+        kl_div = torch.clamp(kl_div, 0.0, 100.0)
+
+        top_k_percentage = None
+        if "top_k_percentage" in self.loss_params:
+            top_k_percentage = self.loss_params["top_k_percentage"]
+
+        loss_mask = None
+        contour_threshold = None
+        if self.loss_type == "geco":
+            reconstruction_threshold = self.loss_params["kappa"]
+            if (
+                "kappa_contour" in self.loss_params
+                and self.loss_params["kappa_contour"] is not None
+            ):
+                loss_mask = [None, mask]
+                contour_threshold = self.loss_params["kappa_contour"]
+
+        # Here we use the posterior sample sampled above
+        rl_sum, rec_loss_mean, _ = self.reconstruction_loss(
+            segm,
+            z_posterior=z_posterior,
+            top_k_percentage=top_k_percentage,
+            mask=loss_mask,
+        )
+
+        # If using contour mask in loss, we get back those in a list. Unpack here.
+        if contour_threshold:
+            contour_loss = rl_sum[1]
+            contour_loss_mean = rec_loss_mean[1]
+            reconstruction_loss = rl_sum[0]
+            rec_loss_mean = rec_loss_mean[0]
+        else:
+            reconstruction_loss = rl_sum
+
+        if self.loss_type == "elbo":
+            if beta is None:
+                beta = self.loss_params["beta"]
+
+            return {
+                "loss": reconstruction_loss + beta * kl_div,
+                "rec_loss": reconstruction_loss,
+                "kl_div": kl_div,
+                "beta": beta,
+            }
+        elif self.loss_type == "geco":
+
+            rec_geco_step_size = self.loss_params["rec_geco_step_size"]
+
+            with torch.no_grad():
+
+                moving_avg_factor = 0.5
+
+                rl = rec_loss_mean.detach()
+                if self._rec_moving_avg is None:
+                    self._rec_moving_avg = rl
+                else:
+                    self._rec_moving_avg = self._rec_moving_avg * moving_avg_factor + rl * (
+                        1 - moving_avg_factor
+                    )
+
+                rc = self._rec_moving_avg - reconstruction_threshold
+
+                cc = 0
+                if contour_threshold:
+                    cl = contour_loss_mean.detach()
+                    if self._contour_moving_avg is None:
+                        self._contour_moving_avg = cl
+                    else:
+                        self._contour_moving_avg = (
+                            self._contour_moving_avg * moving_avg_factor
+                            + cl * (1 - moving_avg_factor)
+                        )
+
+                    cc = self._contour_moving_avg - contour_threshold
+
+                lambda_lower = self.loss_params["clamp_rec"][0]
+                lambda_upper = self.loss_params["clamp_rec"][1]
+
+                self._lambda[0] = (torch.exp(rc * rec_geco_step_size) * self._lambda[0]).clamp(
+                    lambda_lower, lambda_upper
+                )
+                if self._lambda[0].isnan():
+                    self._lambda[0] = lambda_upper
+                if contour_threshold:
+                    lambda_lower_contour = self.loss_params["clamp_contour"][0]
+                    lambda_upper_contour = self.loss_params["clamp_contour"][1]
+
+                    self._lambda[1] = (
+                        torch.exp(cc * rec_geco_step_size) * self._lambda[1]
+                    ).clamp(lambda_lower_contour, lambda_upper_contour)
+                    if self._lambda[1].isnan():
+                        self._lambda[1] = lambda_upper_contour
+
+            # pylint: disable=access-member-before-definition
+            loss = (self._lambda[0] * reconstruction_loss) + kl_div
+
+            result = {
+                "loss": loss,
+                "rec_loss": reconstruction_loss,
+                "kl_div": kl_div,
+                "lambda_rec": self._lambda[0],
+                "moving_avg": self._rec_moving_avg,
+                "reconstruction_threshold": reconstruction_threshold,
+                "rec_constraint": rc,
+            }
+
+            if contour_threshold is not None:
+                result["loss"] = result["loss"] + (self._lambda[1] * contour_loss)
+                result["contour_loss"] = contour_loss
+                result["contour_threshold"] = contour_threshold
+                result["contour_constraint"] = cc
+                result["moving_avg_contour"] = self._contour_moving_avg
+                result["lambda_contour"] = self._lambda[1]
+
+            return result
+
+        else:
+            raise NotImplementedError("Loss must be 'elbo' or 'geco'")
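The GECO branch above treats the reconstruction loss as a constraint (kappa) rather than
weighting it with a fixed beta: a Lagrange multiplier grows while the moving-average
reconstruction loss sits above kappa and decays once the constraint is met. A minimal
standalone sketch of that multiplier dynamics, with made-up loss values, step size and
kappa (none of these numbers come from the model above):

    import torch

    kappa = 0.05        # constraint on the moving-average reconstruction loss
    step_size = 1e-2    # corresponds to rec_geco_step_size
    lam = torch.tensor(1.0)
    moving_avg = None

    for rec_loss in [0.40, 0.30, 0.12, 0.06, 0.04]:  # pretend per-step losses
        rl = torch.tensor(rec_loss)
        moving_avg = rl if moving_avg is None else 0.5 * moving_avg + 0.5 * rl
        constraint = moving_avg - kappa
        # the multiplier rises while the constraint is violated, falls once satisfied
        lam = (torch.exp(constraint * step_size) * lam).clamp(1e-5, 1e5)
        total = lam * rl  # plus the KL term in the full loss
        print(f"rec={rl:.2f} lambda={lam:.4f} total={total:.4f}")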
diff --git a/platipy/imaging/cnn/pseudo_generator.py b/platipy/imaging/cnn/pseudo_generator.py
new file mode 100644
index 00000000..18420322
--- /dev/null
+++ b/platipy/imaging/cnn/pseudo_generator.py
@@ -0,0 +1,100 @@
+import random
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import SimpleITK as sitk
+
+from platipy.imaging.generation.image import insert_sphere
+from platipy.imaging import ImageVisualiser
+
+
+def generate_pseudo_data(data_dir="data", cases=5, size=(24, 32, 32), structures=["a", "b", "c"]):
+    """Generates some pseudo data to use for testing the CNN code
+
+    Args:
+        data_dir (str, optional): Directory in which to store pseudo data. Defaults to "data".
+        cases (int, optional): Number of cases to generate data for. Defaults to 5.
+        size (tuple, optional): The size of the generated images. Defaults to (24, 32, 32).
+        structures (list, optional): A list of structure names to generate. Defaults to
+          ["a", "b", "c"].
+    """
+
+    test_data_directory = Path(data_dir)
+
+    if test_data_directory.exists():
+        print("Data directory already exists, won't regenerate")
+        return
+
+    image_directory = test_data_directory.joinpath("images")
+    label_directory = test_data_directory.joinpath("labels")
+    slice_directory = test_data_directory.joinpath("slices")
+
+    image_directory.mkdir(parents=True, exist_ok=True)
+    label_directory.mkdir(parents=True, exist_ok=True)
+    slice_directory.mkdir(parents=True, exist_ok=True)
+
+    for case, sphere_rad in enumerate(range(5, 5 + cases)):
+
+        xpos = random.randint(6, 24)
+        ypos = random.randint(6, 24)
+
+        mask_arr = np.zeros(size)
+        mask_arr = insert_sphere(
+            mask_arr, sp_radius=sphere_rad, sp_centre=(int(size[0] / 2), ypos, xpos)
+        )
+
+        mask = sitk.GetImageFromArray(mask_arr)
+        mask = sitk.Cast(mask, sitk.sitkUInt8)
+        mask = sitk.BinaryNot(mask)
+
+        ct = sitk.SignedMaurerDistanceMap(mask)
+
+        ct_arr = sitk.GetArrayFromImage(ct)
+        ct_arr[ct_arr < -10] = -1000
+        ct_arr[ct_arr > 20] = 100
+
+        ct = sitk.GetImageFromArray(ct_arr)
+
+        sitk.WriteImage(ct, str(image_directory.joinpath(f"{case}.nii.gz")))
+
+        vis = ImageVisualiser(ct, cut=(int(size[0] / 2), ypos, xpos))
+        masks = {}
+
+        for struct_id, structure in enumerate(structures):
+
+            # Offset alternate structures so they don't overlap exactly
+            x_shift = y_shift = 0
+            if struct_id > 0:
+                if struct_id % 2 == 0:
+                    x_shift = struct_id
+                else:
+                    y_shift = struct_id
+
+            for obs_id, obs in enumerate(range(-4, 5, 2)):
+                obs_rad = sphere_rad + obs
+
+                mask_arr = np.zeros(size)
+                mask_arr = insert_sphere(
+                    mask_arr,
+                    sp_radius=obs_rad,
+                    sp_centre=(int(size[0] / 2), ypos + y_shift, xpos + x_shift),
+                )
+
+                mask = sitk.GetImageFromArray(mask_arr)
+                mask.CopyInformation(ct)
+                mask = sitk.Cast(mask, sitk.sitkUInt8)
+                sitk.WriteImage(
+                    mask, str(label_directory.joinpath(f"{case}_{structure}_{obs_id}.nii.gz"))
+                )
+
+                masks[f"struct_{structure}_obs_{obs_id}_{obs_rad}"] = mask
+
+        vis.add_contour(masks)
+        vis.show()
+        plt.savefig(slice_directory.joinpath(f"{case}.png"))
+        plt.close()
+
+
+if __name__ == "__main__":
+    generate_pseudo_data()
diff --git a/platipy/imaging/cnn/sampler.py b/platipy/imaging/cnn/sampler.py
new file mode 100644
index 00000000..43cddd22
--- /dev/null
+++ b/platipy/imaging/cnn/sampler.py
@@ -0,0 +1,20 @@
+import random
+
+from torch.utils.data import Sampler
+
+
+class ObserverSampler(Sampler):
+    """Samples per-case groups: all observer variants of a case are yielded together.
+
+    Assumes the dataset is ordered so that the ``num_observers`` entries for each
+    case are contiguous.
+    """
+
+    def __init__(self, data_source, num_observers):
+        self.data_source = data_source
+        self.num_observers = num_observers
+
+    def __iter__(self):
+        indices = list(range(int(len(self.data_source) / self.num_observers)))
+        random.shuffle(indices)
+        for i in indices:
+            for o in range(self.num_observers):
+                yield i * self.num_observers + o
+
+    def __len__(self):
+        return len(self.data_source)
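ObserverSampler yields the indices of all observer variants of a case consecutively, so
wrapping it in a BatchSampler with batch_size equal to the number of observers puts every
observer contour for one case into the same batch (the validation code in train.py relies
on this). A small usage sketch with a stand-in dataset (the tensor contents are arbitrary):

    import torch
    from torch.utils.data import BatchSampler, DataLoader, TensorDataset

    from platipy.imaging.cnn.sampler import ObserverSampler

    num_observers = 5
    dataset = TensorDataset(torch.arange(15).float())  # 3 cases x 5 observers

    batch_sampler = BatchSampler(
        ObserverSampler(dataset, num_observers),
        batch_size=num_observers,
        drop_last=False,
    )
    loader = DataLoader(dataset, batch_sampler=batch_sampler)

    for (batch,) in loader:
        # each batch holds the 5 observer entries of a single (shuffled) case
        print(batch.tolist())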
diff --git a/platipy/imaging/cnn/test_hpunet.py b/platipy/imaging/cnn/test_hpunet.py
new file mode 100644
index 00000000..928f96cd
--- /dev/null
+++ b/platipy/imaging/cnn/test_hpunet.py
@@ -0,0 +1,164 @@
+import torch
+
+from platipy.imaging.cnn.hierarchical_prob_unet import (
+    _HierarchicalCore,
+    ResBlock,
+    HierarchicalProbabilisticUnet,
+)
+
+base_channels = 24
+default_channels_per_block = [
+    base_channels,
+    2 * base_channels,
+]
+
+channels_per_block = default_channels_per_block
+down_channels_per_block = [int(i / 2) for i in default_channels_per_block]
+c = torch.rand([3, 1, 32, 32])
+
+fg = torch.ones(c.shape)
+bg = torch.zeros(c.shape)
+labels = torch.cat([fg, bg], axis=1)
+
+hpunet = HierarchicalProbabilisticUnet(
+    filters_per_layer=channels_per_block,
+    latent_dims=[1],
+    loss_type="geco",
+    loss_params={
+        "top_k_percentage": None,
+        "deterministic_top_k": False,
+        "kappa": 0.05,
+        "decay": 0.99,
+        "rate": 1e-2,
+        "clamp_rec": [0.001, 10000],
+        "beta": 5,
+    },
+)
+output = hpunet.sample(c)
+print(output.shape)
+output = hpunet.reconstruct(c, labels)
+print(output.shape)
+loss = hpunet.loss(c, labels)
+loss = hpunet.loss(c, labels)  # called twice to exercise the stateful GECO terms
+print(loss)
+
+
+_NUM_CLASSES = 2
+_BATCH_SIZE = 2
+_SPATIAL_SHAPE = [32, 32]
+_CHANNELS_PER_BLOCK = [5, 7, 9, 11, 13]
+_IMAGE_SHAPE = [_BATCH_SIZE] + [1] + _SPATIAL_SHAPE
+_BOTTLENECK_SIZE = _SPATIAL_SHAPE[0] // 2 ** (len(_CHANNELS_PER_BLOCK) - 1)
+_SEGMENTATION_SHAPE = [_BATCH_SIZE] + [_NUM_CLASSES] + _SPATIAL_SHAPE
+_LATENT_DIMS = [3, 2, 1]
+
+
+def _get_placeholders():
+    """Returns placeholders for the image and segmentation."""
+    img = torch.rand(_IMAGE_SHAPE)
+    seg = torch.rand(_SEGMENTATION_SHAPE)
+    return img, seg
+
+
+def test_shape_of_sample():
+    hpu_net = HierarchicalProbabilisticUnet(
+        latent_dims=_LATENT_DIMS,
+        filters_per_layer=_CHANNELS_PER_BLOCK,
+        num_classes=_NUM_CLASSES,
+    )
+    img, _ = _get_placeholders()
+    sample = hpu_net.sample(img)
+
+    assert list(sample.shape) == _SEGMENTATION_SHAPE
+
+
+def test_shape_of_reconstruction():
+    hpu_net = HierarchicalProbabilisticUnet(
+        latent_dims=_LATENT_DIMS,
+        filters_per_layer=_CHANNELS_PER_BLOCK,
+        num_classes=_NUM_CLASSES,
+    )
+    img, seg = _get_placeholders()
+    reconstruction = hpu_net.reconstruct(img, seg)
+    assert list(reconstruction.shape) == _SEGMENTATION_SHAPE
+
+
+def test_shapes_in_prior():
+    hpu_net = HierarchicalProbabilisticUnet(
+        latent_dims=_LATENT_DIMS,
+        filters_per_layer=_CHANNELS_PER_BLOCK,
+        num_classes=_NUM_CLASSES,
+    )
+    img, _ = _get_placeholders()
+    prior_out = hpu_net._prior(img)
+    distributions = prior_out["distributions"]
+    latents = prior_out["used_latents"]
+    encoder_features = prior_out["encoder_features"]
+    decoder_features = prior_out["decoder_features"]
+
+    # Test number of latent distributions.
+    assert len(distributions) == len(_LATENT_DIMS)
+
+    # Test shapes of latent scales.
+    for level in range(len(_LATENT_DIMS)):
+        latent_spatial_shape = _BOTTLENECK_SIZE * 2 ** level
+        latent_shape = [
+            _BATCH_SIZE,
+            _LATENT_DIMS[level],
+            latent_spatial_shape,
+            latent_spatial_shape,
+        ]
+        assert list(latents[level].shape) == latent_shape
+
+    # Test encoder shapes.
+    for level in range(len(_CHANNELS_PER_BLOCK)):
+        spatial_shape = _SPATIAL_SHAPE[0] // 2 ** level
+        feature_shape = [_BATCH_SIZE, _CHANNELS_PER_BLOCK[level], spatial_shape, spatial_shape]
+
+        assert list(encoder_features[level].shape) == feature_shape
+
+    # Test decoder shape.
+ start_level = len(_LATENT_DIMS) + latent_spatial_shape = _BOTTLENECK_SIZE * 2 ** start_level + latent_shape = [ + _BATCH_SIZE, + _CHANNELS_PER_BLOCK[::-1][start_level], + latent_spatial_shape, + latent_spatial_shape, + ] + + assert list(decoder_features.shape) == latent_shape + + +def test_shape_of_kl(): + hpu_net = HierarchicalProbabilisticUnet( + latent_dims=_LATENT_DIMS, + filters_per_layer=_CHANNELS_PER_BLOCK, + num_classes=_NUM_CLASSES, + ) + img, seg = _get_placeholders() + kl_dict = hpu_net.kl(img, seg) + assert len(kl_dict) == len(_LATENT_DIMS) + + +test_shape_of_sample() +test_shape_of_reconstruction() +test_shapes_in_prior() +test_shape_of_kl() +# if __name__ == "__main__": diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py new file mode 100644 index 00000000..7d865297 --- /dev/null +++ b/platipy/imaging/cnn/train.py @@ -0,0 +1,1173 @@ +# 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import tempfile +import json +from argparse import ArgumentParser + +from pathlib import Path +import matplotlib +import SimpleITK as sitk +import numpy as np +from scipy.optimize import linear_sum_assignment + +import comet_ml # pylint: disable=unused-import +from pytorch_lightning.loggers import CometLogger +from torchmetrics import JaccardIndex + +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.callbacks.early_stopping import EarlyStopping + +import matplotlib.pyplot as plt + +from platipy.imaging.cnn.prob_unet import ProbabilisticUnet + +# from platipy.imaging.cnn.hierarchical_prob_unet import HierarchicalProbabilisticUnet +from platipy.imaging.cnn.unet import l2_regularisation +from platipy.imaging.cnn.dataload import UNetDataModule +from platipy.imaging.cnn.dataset import crop_img_using_localise_model +from platipy.imaging.cnn.utils import ( + preprocess_image, + postprocess_mask, + get_metrics, + resample_mask_to_image, +) +from platipy.imaging.cnn.metrics import probabilistic_dice + +from platipy.imaging import ImageVisualiser +from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask + + +class GECOEarlyStopping(EarlyStopping): + def on_validation_end(self, trainer, pl_module): + # Make sure the GECO lambda metrics are below 0.1 before stopping + logs = trainer.callback_metrics + should_consider_early_stop = True + + if "lambda_rec" in logs and logs["lambda_rec"] >= 0.01: + should_consider_early_stop = False + + if "lambda_contour" in logs and logs["lambda_contour"] >= 0.01: + should_consider_early_stop = False + + if should_consider_early_stop: + self._run_early_stopping_check(trainer) + + def on_train_epoch_end(self, trainer, pl_module): + pass + + +class ProbUNet(pl.LightningModule): + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.save_hyperparameters() + + loss_params = None + + if self.hparams.loss_type == "elbo": + 
loss_params = { + "beta": self.hparams.beta, + } + + if self.hparams.loss_type == "geco": + loss_params = { + "kappa": self.hparams.kappa, + "clamp_rec": self.hparams.clamp_rec, + "clamp_contour": self.hparams.clamp_contour, + "kappa_contour": self.hparams.kappa_contour, + "rec_geco_step_size": self.hparams.rec_geco_step_size, + } + + loss_params["top_k_percentage"] = self.hparams.top_k_percentage + loss_params[ + "contour_loss_lambda_threshold" + ] = self.hparams.contour_loss_lambda_threshold + loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight + + self.use_structure_context = self.hparams.use_structure_context + + if self.hparams.prob_type == "prob": + self.prob_unet = ProbabilisticUnet( + self.hparams.input_channels, + len(self.hparams.structures) + + 1, # Add 1 to num classes for background class + self.hparams.filters_per_layer, + self.hparams.latent_dim, + self.hparams.no_convs_fcomb, + self.hparams.loss_type, + loss_params, + self.hparams.ndims, + dropout_probability=self.hparams.dropout_probability, + use_structure_context=self.use_structure_context, + ) + elif self.hparams.prob_type == "hierarchical": + raise NotImplementedError("Hierarchical Prob UNet current not working...") + # self.prob_unet = HierarchicalProbabilisticUnet( + # input_channels=self.hparams.input_channels, + # num_classes=len(self.hparams.structures), + # filters_per_layer=self.hparams.filters_per_layer, + # down_channels_per_block=self.hparams.down_channels_per_block, + # latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), + # convs_per_block=self.hparams.convs_per_block, + # blocks_per_level=self.hparams.blocks_per_level, + # loss_type=self.hparams.loss_type, + # loss_params=loss_params, + # ndims=self.hparams.ndims, + # ) + + self.validation_directory = None + self.kl_div = None + + self.stddevs = np.linspace(-2, 2, self.hparams.num_observers) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Probabilistic UNet") + parser.add_argument("--prob_type", type=str, default="prob") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--lr_lambda", type=float, default=0.99) + parser.add_argument("--input_channels", type=int, default=1) + parser.add_argument( + "--filters_per_layer", + nargs="+", + type=int, + default=[64 * (2**x) for x in range(5)], + ) + parser.add_argument( + "--down_channels_per_block", nargs="+", type=int, default=None + ) + parser.add_argument("--latent_dim", type=int, default=6) + parser.add_argument("--no_convs_fcomb", type=int, default=4) + parser.add_argument("--convs_per_block", type=int, default=2) + parser.add_argument("--blocks_per_level", type=int, default=1) + parser.add_argument("--loss_type", type=str, default="elbo") + parser.add_argument("--beta", type=float, default=1.0) + parser.add_argument("--kappa", type=float, default=0.02) + parser.add_argument("--kappa_contour", type=float, default=None) + parser.add_argument("--rec_geco_step_size", type=float, default=1e-2) + parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + parser.add_argument( + "--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3] + ) + parser.add_argument("--top_k_percentage", type=float, default=None) + parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) + parser.add_argument( + "--contour_loss_weight", type=float, default=0.0 + ) # no longer 
used
+        parser.add_argument("--epochs_all_rec", type=int, default=0)  # no longer used
+        parser.add_argument("--dropout_probability", type=float, default=0.0)
+        parser.add_argument("--use_structure_context", type=int, default=0)
+
+        return parent_parser
+
+    def forward(self, x):
+        self.prob_unet.forward(x, None, False)
+        return x
+
+    def configure_optimizers(self):
+        # One parameter group per sub-network so weight decay and learning rate can be
+        # set consistently for the UNet, prior, posterior and fcomb networks.
+        params = [
+            {
+                "params": self.prob_unet.unet.parameters(),
+                "weight_decay": self.hparams.weight_decay,
+                "lr": 1e-5,
+            }
+        ]
+
+        if self.prob_unet.prior is not None:
+            param_list = [
+                self.prob_unet.prior.parameters(),
+                self.prob_unet.posterior.parameters(),
+                self.prob_unet.fcomb.parameters(),
+            ]
+        else:
+            param_list = [
+                self.prob_unet.posterior.parameters(),
+                self.prob_unet.fcomb.parameters(),
+            ]
+        for m in param_list:
+            params += [
+                {"params": m, "weight_decay": self.hparams.weight_decay, "lr": 1e-5}
+            ]
+
+        optimizer = torch.optim.Adam(params)
+
+        scheduler = torch.optim.lr_scheduler.CyclicLR(
+            optimizer,
+            base_lr=self.hparams.learning_rate,
+            max_lr=self.hparams.learning_rate * 10,
+            step_size_up=50,
+            mode="exp_range",
+            gamma=0.99,
+            cycle_momentum=False,
+        )
+
+        return [optimizer], [scheduler]
+
+    def infer(
+        self,
+        img,
+        context_map=None,
+        seg=None,
+        num_samples=1,
+        sample_strategy="mean",
+        latent_dim=True,
+        spaced_range=[-1.5, 1.5],
+        preprocess=True,
+        return_latent_space=False,
+    ):
+        # sample_strategy is one of "mean", "random" or "spaced"
+
+        if not hasattr(latent_dim, "__iter__"):
+            latent_dim = [
+                latent_dim,
+            ] * self.hparams.latent_dim
+
+        if sample_strategy == "mean":
+            samples = [
+                {
+                    "name": "mean",
+                    "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to(
+                        self.device
+                    ),
+                    "preds": [],
+                }
+            ]
+        elif sample_strategy == "random":
+            samples = [
+                {
+                    "name": f"random_{i}",
+                    "std_dev_from_mean": torch.Tensor(
+                        [
+                            np.random.normal(0, 1.0, 1)[0] if d else 0.0
+                            for d in latent_dim
+                        ]
+                    ).to(self.device),
+                    "preds": [],
+                }
+                for i in range(num_samples)
+            ]
+        elif sample_strategy == "spaced":
+            if self.hparams.prob_type == "hierarchical":
+                latent_dim = [True] * (len(self.hparams.filters_per_layer) - 1)
+            samples = [
+                {
+                    "name": f"spaced_{s:.2f}",
+                    "std_dev_from_mean": torch.Tensor(
+                        [s if d else 0.0 for d in latent_dim]
+                    ).to(self.device),
+                    "preds": [],
+                }
+                for s in np.linspace(spaced_range[0], spaced_range[1], num_samples)
+            ]
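+        # Sketch of what the "spaced" strategy produces for num_samples=5 and
+        # spaced_range=[-1.5, 1.5] (hypothetical values, for illustration only):
+        #
+        #     np.linspace(-1.5, 1.5, 5)  ->  [-1.5, -0.75, 0.0, 0.75, 1.5]
+        #
+        # Each offset is applied to every enabled latent dimension, so the decoded
+        # predictions sweep from -1.5 to +1.5 standard deviations around the mean.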
+        with torch.no_grad():
+            if preprocess:
+                if self.hparams.crop_using_localise_model:
+                    localise_path = self.hparams.crop_using_localise_model.format(
+                        fold=self.hparams.fold
+                    )
+                    img = crop_img_using_localise_model(
+                        img,
+                        localise_path,
+                        spacing=self.hparams.spacing,
+                        crop_to_grid_size=self.hparams.localise_voxel_grid_size,
+                        context_seg=seg,
+                    )
+                else:
+                    img = preprocess_image(
+                        img,
+                        spacing=self.hparams.spacing,
+                        crop_to_grid_size_xy=self.hparams.crop_to_grid_size,
+                        intensity_scaling=self.hparams.intensity_scaling,
+                        intensity_window=self.hparams.intensity_window,
+                    )
+
+            img_arr = sitk.GetArrayFromImage(img)
+
+            if context_map is not None:
+                context_map = resample_mask_to_image(img, context_map)
+                cmap_arr = sitk.GetArrayFromImage(context_map)
+
+            if seg is not None:
+                seg = resample_mask_to_image(img, seg)
+                seg_arr = sitk.GetArrayFromImage(seg)
+
+            if self.hparams.ndims == 2:
+                slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])]
+
+                if context_map is not None:
+                    cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])]
+
+                if seg is not None:
+                    seg_slices = [seg_arr[z, :, :] for z in range(seg_arr.shape[0])]
+            else:
+                slices = [img_arr]
+                if context_map is not None:
+                    cmap_slices = [cmap_arr]
+
+                if seg is not None:
+                    seg_slices = [seg_arr]
+
+            for idx, i in enumerate(slices):
+                x = torch.Tensor(i).to(self.device)
+                x = x.unsqueeze(0)
+                x = x.unsqueeze(0)
+
+                if context_map is not None:
+                    c = torch.Tensor(cmap_slices[idx]).to(self.device)
+                    c = c.unsqueeze(0)
+                    c = c.unsqueeze(0)
+
+                    x = torch.cat((x, c), dim=1)
+
+                if seg is not None:
+                    s = torch.Tensor(seg_slices[idx]).to(self.device)
+                    s = s.unsqueeze(0)
+                    s = s.unsqueeze(0)
+
+                    # Add in background channel
+                    not_s = 1 - s.max(axis=1).values
+                    not_s = torch.unsqueeze(not_s, dim=1)
+                    s = torch.cat((not_s, s), dim=1).float()
+
+                if self.hparams.prob_type == "prob":
+                    if seg is not None:
+                        self.prob_unet.forward(x, cseg=s)
+                    else:
+                        self.prob_unet.forward(x)
+
+                if return_latent_space:
+                    return self.prob_unet.prior_latent_space
+
+                for sample in samples:
+                    if self.hparams.prob_type == "prob":
+                        if sample["name"] == "mean":
+                            y = self.prob_unet.sample(testing=True, use_mean=True)
+                        else:
+                            y = self.prob_unet.sample(
+                                testing=True,
+                                use_mean=False,
+                                sample_x_stddev_from_mean=sample["std_dev_from_mean"],
+                            )
+
+                    y = y.squeeze(0)
+                    y = torch.sigmoid(y)
+                    sample["preds"].append(y.cpu().detach().numpy())
+
+        result = {}
+        for sample in samples:
+            pred_arr = sample["preds"][0]
+
+            if self.hparams.ndims == 2:
+                pred_arr = np.expand_dims(pred_arr, 1)
+                if len(sample["preds"]) > 1:
+                    pred_arr = np.stack(sample["preds"], axis=1)
+
+            result[sample["name"]] = {}
+
+            for idx, 
structure in enumerate(self.hparams.structures): + pred = sitk.GetImageFromArray(pred_arr[idx + 1]) # Skip the background + pred = pred > 0.5 # Threshold softmax at 0.5 + pred = sitk.Cast(pred, sitk.sitkUInt8) + + pred.CopyInformation(img) + pred = postprocess_mask(pred) + pred = sitk.Resample( + pred, img, sitk.Transform(), sitk.sitkNearestNeighbor + ) + + result[sample["name"]][structure] = pred + + return result + + def validate( + self, + img, + manual_observers, + samples, + mean, + matching_type="best", + window=[-0.5, 1.0], + ): + metrics = {"DSC": "max", "HD": "min", "ASD": "min"} + result = {} + + contour_cmaps = ["RdPu", "YlOrRd", "GnBu", "OrRd", "YlGn", "YlGnBu"] + structures = self.hparams.structures + + try: + cut = get_com(mean["mean"][structures[0]]) + except ValueError: + cut = [int(i / 2) for i in img.GetSize()][::-1] + + vis = ImageVisualiser(img, cut=cut, figure_size_in=16, window=window) + + mean_contours = {} + for idx, structure in enumerate(structures): + color_map = matplotlib.colormaps.get_cmap( + contour_cmaps[idx % len(structures)] + ) + mean_contours[f"mean_{structure}"] = mean["mean"][structure] + + vis.add_contour( + mean_contours, color=color_map(0.35), linewidth=3, show_legend=False + ) + + manual_color = color_map(0.9) + + manual_observers_struct = { + f"{man_struct}_{structure}": manual_observers[man_struct][structure] + for man_struct in manual_observers + } + +# vis.add_contour( +# manual_observers_struct, +# color=manual_color, +# linewidth=0.5, +# show_legend=False, +# ) + + intersection_mask = get_intersection_mask(manual_observers_struct) + union_mask = get_union_mask(manual_observers_struct) + + vis.add_contour( + intersection_mask, + name=f"intersection_{structure}", + color=manual_color, + linewidth=3, + ) + vis.add_contour( + union_mask, name=f"union_{structure}", color=manual_color, linewidth=3 + ) + + samples_struct = { + f"{sample_struct}_{structure}": samples[sample_struct][structure] + for sample_struct in samples + } + vis.add_contour( + samples_struct, + linewidth=1.5, + color={ + s: c + for s, c in zip( + samples_struct, + color_map(np.linspace(0.1, 0.7, len(samples_struct))), + ) + }, + ) + + # vis.set_limits_from_label(union_mask, expansion=30) + + sim = { + k: np.zeros((len(samples_struct), len(manual_observers_struct))) + for k in metrics + } + msim = { + k: np.zeros((len(samples_struct), len(manual_observers_struct))) + for k in metrics + } + for sid, samp in enumerate(samples_struct): + for oid, obs in enumerate(manual_observers_struct): + sample_metrics = get_metrics( + manual_observers_struct[obs], samples_struct[samp] + ) + mean_metrics = get_metrics( + manual_observers_struct[obs], mean_contours[f"mean_{structure}"] + ) + + for k in sample_metrics: + sim[k][sid, oid] = sample_metrics[k] + msim[k][sid, oid] = mean_metrics[k] + + result[f"probnet_{structure}"] = {k: [] for k in metrics} + result[f"unet_{structure}"] = {k: [] for k in metrics} + for k in sim: + val = sim[k] + if matching_type == "hungarian": + if metrics[k] == "max": + val = -val + row_idx, col_idx = linear_sum_assignment(val) + prob_unet_mean = sim[k][row_idx, col_idx].mean() + else: + if metrics[k] == "max": + prob_unet_mean = val.max() + else: + prob_unet_mean = val.min() + result[f"probnet_{structure}"][k].append(prob_unet_mean) + + val = msim[k] + if matching_type == "hungarian": + if metrics[k] == "max": + val = -val + row_idx, col_idx = linear_sum_assignment(val) + unet_mean = msim[k][row_idx, col_idx].mean() + else: + if metrics[k] == "max": + unet_mean = 
val.max() + else: + unet_mean = val.min() + result[f"unet_{structure}"][k].append(unet_mean) + + fig = vis.show() + + return result, fig + + def training_step(self, batch, _): + x, c, y, cy, m, _ = batch + + # Add background layer for one-hot encoding + not_y = 1 - y.max(axis=1).values + not_y = torch.unsqueeze(not_y, dim=1) + y = torch.cat((not_y, y), dim=1).float() + + not_cy = 1 - cy.max(axis=1).values + not_cy = torch.unsqueeze(not_cy, dim=1) + cy = torch.cat((not_cy, cy), dim=1).float() + + # Concat context map to image if we have one + if c.numel() > 0: + x = torch.cat((x, c), dim=1).float() + + print(f"{y.shape} {cy.shape}") + + # self.prob_unet.forward(x, y, training=True) + if self.hparams.prob_type == "prob": + self.prob_unet.forward(x, y, cy, training=True) + # else: + # self.prob_unet.forward(x, y) + + if self.hparams.prob_type == "prob": + loss = self.prob_unet.loss(y, mask=m) + # else: + # loss = self.prob_unet.loss(x, y, mask=m) + + training_loss = loss["loss"] + + # Using weight decay instead + # if self.hparams.prob_type == "prob": + # reg_loss = ( + # l2_regularisation(self.prob_unet.posterior) + # + l2_regularisation(self.prob_unet.prior) + # + l2_regularisation(self.prob_unet.fcomb.layers) + # ) + # training_loss = training_loss + 1e-5 * reg_loss + self.log( + "training_loss", + training_loss.detach(), + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) + + self.kl_div = loss["kl_div"].detach().cpu() + + for k in loss: + if k == "loss": + continue + self.log( + k, + loss[k].detach() if isinstance(loss[k], torch.Tensor) else loss[k], + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) + return training_loss + + def validation_step(self, batch, _): + if self.validation_directory is None: + self.validation_directory = Path(tempfile.mkdtemp()) + + n = self.hparams.num_observers + m = self.hparams.num_observers + + with torch.set_grad_enabled(False): + x, c, y, cy, _, info = batch + + # Save off slices/volumes for analysis of entire structure in end of validation step + for s in range(y.shape[0]): + img_file = self.validation_directory.joinpath( + f"img_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(img_file, x[s].squeeze(0).cpu().numpy()) + + if c.numel() > 0: + cmap_file = self.validation_directory.joinpath( + f"cmap_{info['case'][s]}_{info['z'][s]}.npy" + ) + np.save(cmap_file, c[s].squeeze(0).cpu().numpy()) + + mask_file = self.validation_directory.joinpath( + f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy" + ) + np.save(mask_file, y[s].cpu().numpy()) + + # Image (and context map) will be same for all in batch + x = x[0].unsqueeze(0) + if c.numel() > 0: + c = c[0].unsqueeze(0) + if self.hparams.ndims == 2: + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z") + else: + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0, 0])) + + if self.hparams.ndims == 2: + x = x.repeat(m, 1, 1, 1) + + if c.numel() > 0: + c = c.repeat(m, 1, 1, 1) + else: + x = x.repeat(m, 1, 1, 1, 1) + + if c.numel() > 0: + c = c.repeat(m, 1, 1, 1, 1) + + if c.numel() > 0: + x = torch.cat((x, c), dim=1) + + seg = None + if self.use_structure_context: + not_y = 1 - y.max(axis=1).values + not_y = torch.unsqueeze(not_y, dim=1) + seg = torch.cat((not_y, y), dim=1).float() + + self.prob_unet.forward(x, cseg=seg) + # loss = self.prob_unet.loss(seg) + # print(f"VAL LOSS: {loss}") + + py = self.prob_unet.sample(testing=True) + py = py.to("cpu") + + pred_y = torch.zeros(py[:, 0, :].shape).int() + for b in range(py.shape[0]): 
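+                # Collapse the class channel with argmax to turn the one-hot logits into
+                # an integer label map per latent sample (0 = background).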
+ pred_y[b] = py[b, :].argmax(0).int() + + y = y.squeeze(1) + y = y.int() + y = y.to("cpu") + + + cy = cy.squeeze(1) + cy = cy.int() + cy = cy.to("cpu") + + # TODO Make this work for multi class + # Intersection over Union (also known as Jaccard Index) + jaccard = JaccardIndex(num_classes=2) + term_1 = 0 + for i in range(n): + for j in range(m): + if pred_y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], y[j]) + term_1 += 1 - iou + term_1 = term_1 * (2 / (m * n)) + + term_2 = 0 + for i in range(n): + for j in range(n): + if pred_y[i].sum() + pred_y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], pred_y[j]) + term_2 += 1 - iou + term_2 = term_2 * (1 / (n * n)) + + term_3 = 0 + for i in range(m): + for j in range(m): + if y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(y[i], y[j]) + term_3 += 1 - iou + term_3 = term_3 * (1 / (m * m)) + + D_ged = term_1 - term_2 - term_3 + + contours = {} + contour_colors = {} + for o in range(n): + obs_y = y[o].float() + if self.hparams.ndims == 2: + obs_y = obs_y.unsqueeze(0) + contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) + contour_colors[f"obs_{o}"] = (0.3, 0.6, 0.3) + for mm in range(m): + samp_pred = pred_y[mm].float() + if self.hparams.ndims == 2: + samp_pred = samp_pred.unsqueeze(0) + contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) + contour_colors[f"sample_{mm}"] = (0.1, 0.1, 0.8) + + if self.use_structure_context: + for o in range(n): + obs_y = cy[o].float() + if self.hparams.ndims == 2: + obs_y = obs_y.unsqueeze(0) + contours[f"compobs_{o}"] = sitk.GetImageFromArray(obs_y) + contour_colors[f"compobs_{o}"] = (0.6, 0.3, 0.3) + + vis.add_contour(contours, color=contour_colors) + vis.show() + + figure_path = f"ged_{info['z'][s]}.png" + plt.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + self.log("GED", D_ged) + + return info + + def validation_epoch_end(self, validation_step_outputs): + cases = {} + for info in validation_step_outputs: + for case, z, observer in zip(info["case"], info["z"], info["observer"]): + if not case in cases: + cases[case] = {"slices": z.item(), "observers": [observer]} + else: + if z.item() > cases[case]["slices"]: + cases[case]["slices"] = z.item() + if not observer in cases[case]["observers"]: + cases[case]["observers"].append(observer) + + metrics = ["DSC", "HD", "ASD"] + computed_metrics = { + **{ + f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures + }, + **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures}, + } + + if len(cases) == 0: + return + + prob_surface_dice = 0 + prob_dice = 0 + + for case in cases: + img_arrs = [] + cmap_arrs = [] + cmap_arr = None + slices = [] + + if self.hparams.ndims == 2: + for z in range(cases[case]["slices"] + 1): + img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy") + if img_file.exists(): + img_arrs.append(np.load(img_file)) + slices.append(z) + + cmap_file = self.validation_directory.joinpath( + f"cmap_{case}_{z}.npy" + ) + if cmap_file.exists(): + cmap_arrs.append(np.load(cmap_file)) + + img_arr = np.stack(img_arrs) + + if len(cmap_arrs) > 0: + cmap_arr = np.stack(cmap_arr) + + else: + img_file = self.validation_directory.joinpath(f"img_{case}_0.npy") + img_arr = np.load(img_file) + + cmap_file = self.validation_directory.joinpath(f"cmap_{case}_0.npy") + if cmap_file.exists(): + cmap_arr = np.load(cmap_file) + + img = sitk.GetImageFromArray(img_arr) + 
img.SetSpacing(self.hparams.spacing) + + observers = {} + for _, observer in enumerate(cases[case]["observers"]): + if self.hparams.ndims == 2: + mask_arrs = [] + for z in slices: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + + mask_arrs.append(np.load(mask_file)) + + mask_arr = np.stack(mask_arrs, axis=1) + + else: + mask_file = self.validation_directory.joinpath( + f"mask_{case}_{z}_{observer}.npy" + ) + mask_arr = np.load(mask_file) + + observers[f"manual_{observer}"] = {} + for idx, structure in enumerate(self.hparams.structures): + mask = sitk.GetImageFromArray(mask_arr[idx]) + mask = sitk.Cast(mask, sitk.sitkUInt8) + mask.CopyInformation(img) + observers[f"manual_{observer}"][structure] = mask + + context_map = None + if cmap_arr is not None: + context_map = sitk.GetImageFromArray(cmap_arr) + context_map.SetSpacing(self.hparams.spacing) + + seg = None + if self.use_structure_context: + # Staple the man observers to pass in as context seg + masks = [] + for man_obs in observers: + masks.append(observers[man_obs][structure]) + + stapled = sitk.STAPLE(masks) + stapled = stapled > 0.5 + stapled = sitk.Cast(stapled, sitk.sitkUInt8) + seg = stapled + +# try: + mean = self.infer( + img, + context_map=context_map, + seg=seg, + num_samples=1, + sample_strategy="mean", + preprocess=False, + ) + samples = self.infer( + img, + context_map=context_map, + seg=seg, + sample_strategy="spaced", + num_samples=11, + spaced_range=[-2, 2], + preprocess=False, + ) +# except Exception as e: +# print(f"ERROR DURING VALIDATION INFERENCE: {e}") +# return + + + # try: + result, fig = self.validate( + img, observers, samples, mean, matching_type="best" + ) + # except Exception as e: + # print(f"ERROR DURING VALIDATION VALIDATE: {e}") + # return + + figure_path = f"valid_{case}.png" + fig.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + for t in result: + for m in metrics: + computed_metrics[f"{t}_{m}"] += result[t][m] + + # Compute the probabilistic (surface) dice + for idx, structure in enumerate(self.hparams.structures): + gt_labels = [] + for _, observer in enumerate(cases[case]["observers"]): + gt_labels.append(observers[f"manual_{observer}"][structure]) + + sample_labels = [] + for rk in samples: + sample_labels.append(samples[rk][structure]) + + prob_dice += probabilistic_dice( + gt_labels, sample_labels, dsc_type="dsc" + ) + prob_surface_dice += probabilistic_dice( + gt_labels, sample_labels, dsc_type="sdsc", tau=3 + ) + + prob_dice = prob_dice / len(cases) + if np.isnan(prob_dice): + prob_dice = 0 + self.log( + "probabilisticDice", + prob_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + prob_surface_dice = prob_surface_dice / len(cases) + if np.isnan(prob_surface_dice): + prob_surface_dice = 0 + self.log( + "probabilisticSurfaceDice", + prob_surface_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.kl_div: + p = u = 0 + for s in self.hparams.structures: + p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() + u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() + + p /= len(self.hparams.structures) + u /= len(self.hparams.structures) + computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div + + for cm in computed_metrics: + self.log( + cm, + np.array(computed_metrics[cm]).mean(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) 
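+        # Worked micro-example of the generalised energy distance logged as "GED" in
+        # validation_step (hypothetical values for two samples and two observers):
+        #   mean cross distance   E[1-IoU(pred, obs)]   = 0.40  ->  term_1 = 2 * 0.40
+        #   mean within samples   E[1-IoU(pred, pred')] = 0.20
+        #   mean within observers E[1-IoU(obs, obs')]   = 0.30
+        #   GED = 0.80 - 0.20 - 0.30 = 0.30 (0 when predicted and observer
+        #   variability match; lower is better).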
+ + # shutil.rmtree(self.validation_directory) + + +def main(args, config_json_path=None): + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + # args.default_root_dir = str(args.working_dir) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) + args.accumulate_grad_batches = {0: 5, 5: 10, 10: 15} + # args.precision = 16 + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + + data_module = UNetDataModule(**dict_args) + + prob_unet = ProbUNet(**dict_args) + + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + + lr_monitor = LearningRateMonitor(logging_interval="step") + trainer.callbacks.append(lr_monitor) + + # Save the best model + if args.checkpoint_var: + checkpoint_callback = ModelCheckpoint( + monitor=args.checkpoint_var, + dirpath=args.default_root_dir, + filename="probunet-{epoch:02d}-{" + args.checkpoint_var + ":.2f}", + save_top_k=1, + mode=args.checkpoint_mode, + ) + trainer.callbacks.append(checkpoint_callback) + + if args.early_stopping_var: + early_stop_callback = GECOEarlyStopping( + monitor=args.early_stopping_var, + min_delta=args.early_stopping_min_delta, + patience=args.early_stopping_patience, + verbose=True, + mode=args.early_stopping_mode, + ) + trainer.callbacks.append(early_stop_callback) + + trainer.fit(prob_unet, data_module) + + +def parse_config_file(config_json_path, args): + with open(config_json_path, "r") as f: + params = json.load(f) + for key in params: + args.append(f"--{key}") + + if isinstance(params[key], list): + for list_val in params[key]: + args.append(str(list_val)) + else: + args.append(str(params[key])) + + return args + + +if __name__ == "__main__": + args = None + config_json_path = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... 
+ if sys.argv[-1].endswith(".json"): + config_json_path = sys.argv[-1] + args = parse_config_file(config_json_path, []) + + arg_parser = ArgumentParser() + arg_parser = ProbUNet.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument( + "--seed", type=int, default=42, help="an integer to use as seed" + ) + arg_parser.add_argument( + "--experiment", type=str, default="default", help="Name of experiment" + ) + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1]) + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + arg_parser.add_argument("--early_stopping_var", type=str, default=None) + arg_parser.add_argument("--early_stopping_min_delta", type=float, default=0.01) + arg_parser.add_argument("--early_stopping_patience", type=int, default=50) + arg_parser.add_argument("--early_stopping_mode", type=str, default="max") + arg_parser.add_argument("--checkpoint_var", type=str, default=None) + arg_parser.add_argument("--checkpoint_mode", type=str, default="max") + + parsed_args = arg_parser.parse_args(args) + + # Check if config arg parsed, if so take over values and reparse + if parsed_args.config: + print("parseing args") + args = parse_config_file(parsed_args.config, sys.argv[1:]) + parsed_args = arg_parser.parse_args(args) + + main(parsed_args) diff --git a/platipy/imaging/cnn/train_debug.py b/platipy/imaging/cnn/train_debug.py new file mode 100644 index 00000000..d87d0468 --- /dev/null +++ b/platipy/imaging/cnn/train_debug.py @@ -0,0 +1,1134 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os +import tempfile +import json +from argparse import ArgumentParser + +from pathlib import Path +import matplotlib +import SimpleITK as sitk +import numpy as np +from scipy.optimize import linear_sum_assignment + +import comet_ml # pylint: disable=unused-import +from pytorch_lightning.loggers import CometLogger +from torchmetrics import JaccardIndex + +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.callbacks.early_stopping import EarlyStopping + +import matplotlib.pyplot as plt + +from platipy.imaging.cnn.prob_unet import ProbabilisticUnet + +# from platipy.imaging.cnn.hierarchical_prob_unet import HierarchicalProbabilisticUnet +from platipy.imaging.cnn.unet import l2_regularisation +from platipy.imaging.cnn.dataload import UNetDataModule +from platipy.imaging.cnn.dataset import crop_img_using_localise_model +from platipy.imaging.cnn.utils import ( + preprocess_image, + postprocess_mask, + get_metrics, + resample_mask_to_image, +) +from platipy.imaging.cnn.metrics import probabilistic_dice + +from platipy.imaging import ImageVisualiser +from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask + + +class GECOEarlyStopping(EarlyStopping): + def on_validation_end(self, trainer, pl_module): + # Make sure the GECO lambda metrics are below 0.1 before stopping + logs = trainer.callback_metrics + should_consider_early_stop = True + + if "lambda_rec" in logs and logs["lambda_rec"] >= 0.01: + should_consider_early_stop = False + + if "lambda_contour" in logs and logs["lambda_contour"] >= 0.01: + should_consider_early_stop = False + + if should_consider_early_stop: + self._run_early_stopping_check(trainer) + + def on_train_epoch_end(self, trainer, pl_module): + pass + + +class ProbUNet(pl.LightningModule): + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.save_hyperparameters() + + loss_params = None + + if self.hparams.loss_type == "elbo": + loss_params = { + "beta": self.hparams.beta, + } + + if self.hparams.loss_type == "geco": + loss_params = { + "kappa": self.hparams.kappa, + "clamp_rec": self.hparams.clamp_rec, + "clamp_contour": self.hparams.clamp_contour, + "kappa_contour": self.hparams.kappa_contour, + "rec_geco_step_size": self.hparams.rec_geco_step_size, + } + + loss_params["top_k_percentage"] = self.hparams.top_k_percentage + loss_params[ + "contour_loss_lambda_threshold" + ] = self.hparams.contour_loss_lambda_threshold + loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight + + self.use_structure_context = self.hparams.use_structure_context + + if self.hparams.prob_type == "prob": + self.prob_unet = ProbabilisticUnet( + self.hparams.input_channels, + len(self.hparams.structures) + + 1, # Add 1 to num classes for background class + self.hparams.filters_per_layer, + self.hparams.latent_dim, + self.hparams.no_convs_fcomb, + self.hparams.loss_type, + loss_params, + self.hparams.ndims, + dropout_probability=self.hparams.dropout_probability, + use_structure_context=self.use_structure_context, + ) + elif self.hparams.prob_type == "hierarchical": + raise NotImplementedError("Hierarchical Prob UNet current not working...") + # self.prob_unet = HierarchicalProbabilisticUnet( + # input_channels=self.hparams.input_channels, + # num_classes=len(self.hparams.structures), + # filters_per_layer=self.hparams.filters_per_layer, + # down_channels_per_block=self.hparams.down_channels_per_block, + # 
latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), + # convs_per_block=self.hparams.convs_per_block, + # blocks_per_level=self.hparams.blocks_per_level, + # loss_type=self.hparams.loss_type, + # loss_params=loss_params, + # ndims=self.hparams.ndims, + # ) + + self.validation_directory = None + self.kl_div = None + + self.stddevs = np.linspace(-3, 3, self.hparams.num_observers) + + @staticmethod + def add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Probabilistic UNet") + parser.add_argument("--prob_type", type=str, default="prob") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--lr_lambda", type=float, default=0.99) + parser.add_argument("--input_channels", type=int, default=1) + parser.add_argument( + "--filters_per_layer", + nargs="+", + type=int, + default=[64 * (2**x) for x in range(5)], + ) + parser.add_argument( + "--down_channels_per_block", nargs="+", type=int, default=None + ) + parser.add_argument("--latent_dim", type=int, default=6) + parser.add_argument("--no_convs_fcomb", type=int, default=4) + parser.add_argument("--convs_per_block", type=int, default=2) + parser.add_argument("--blocks_per_level", type=int, default=1) + parser.add_argument("--loss_type", type=str, default="elbo") + parser.add_argument("--beta", type=float, default=1.0) + parser.add_argument("--kappa", type=float, default=0.02) + parser.add_argument("--kappa_contour", type=float, default=None) + parser.add_argument("--rec_geco_step_size", type=float, default=1e-2) + parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + parser.add_argument( + "--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3] + ) + parser.add_argument("--top_k_percentage", type=float, default=None) + parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) + parser.add_argument( + "--contour_loss_weight", type=float, default=0.0 + ) # no longer used + parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used + parser.add_argument("--dropout_probability", type=float, default=0.0) + parser.add_argument("--use_structure_context", type=int, default=0) + + return parent_parser + + def forward(self, x): + self.prob_unet.forward(x, None, False) + return x + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 + ) + + return optimizer + params = [ + { + "params": self.prob_unet.unet.parameters(), + "weight_decay": self.hparams.weight_decay, + "lr": 1e-4, + } + ] + + if self.prob_unet.prior is not None: + param_list =[ + self.prob_unet.prior.parameters(), + self.prob_unet.posterior.parameters(), + self.prob_unet.fcomb.parameters(), + ] + else: + param_list =[ +# self.prob_unet.posterior.parameters(), + self.prob_unet.fcomb.parameters(), + ] + for m in param_list: + params += [ + {"params": m, "weight_decay": self.hparams.weight_decay, "lr": 1e-4} + ] + + optimizer = torch.optim.Adam(params) + + lr_lambda_unet = lambda epoch: self.hparams.lr_lambda ** (epoch) + lr_lambda_prob = lambda epoch: 0.99 ** (epoch) + + # max_epochs = self.hparams.max_epochs + # lr_lambda = lambda x: np.interp(((np.sin(x/(max_epochs/8)) * np.sin(x/(max_epochs/4)))), np.array([-1,0,1]), np.array([0.1,1,10])) + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=[lr_lambda_unet, lr_lambda_prob] + ) + # scheduler = 
torch.optim.lr_scheduler.CyclicLR( + # optimizer, + # base_lr=self.hparams.learning_rate / 10, + # max_lr=self.hparams.learning_rate, + # step_size_up=20, + # mode="exp_range", + # gamma=0.999, + # cycle_momentum=False + # ) + #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + # optimizer, 50, eta_min=1e-6, verbose=True + #) + + return [optimizer], [scheduler] + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + "max", + patience=20, + threshold=0.1e-2, + factor=0.75 + # optimizer, "max", patience=200, threshold=0.75, factor=0.5 + ), + "monitor": "probabilisticDice", + }, + } + + def infer( + self, + img, + context_map=None, + seg=None, + num_samples=1, + sample_strategy="mean", + latent_dim=True, + spaced_range=[-1.5, 1.5], + preprocess=True, + ): + # sample strategy in "mean", "random", "spaced" + + if not hasattr(latent_dim, "__iter__"): + latent_dim = [ + latent_dim, + ] * self.hparams.latent_dim + + if sample_strategy == "mean": + samples = [ + { + "name": "mean", + "std_dev_from_mean": torch.Tensor([0.0] * len(latent_dim)).to( + self.device + ), + "preds": [], + } + ] + elif sample_strategy == "random": + samples = [ + { + "name": f"random_{i}", + "std_dev_from_mean": torch.Tensor( + [ + np.random.normal(0, 1.0, 1)[0] if d else 0.0 + for d in latent_dim + ] + ).to(self.device), + "preds": [], + } + for i in range(num_samples) + ] + elif sample_strategy == "spaced": + if self.hparams.prob_type == "hierarchical": + latent_dim = [True] * (len(self.hparams.filters_per_layer) - 1) + samples = [ + { + "name": f"spaced_{s:.2f}", + "std_dev_from_mean": torch.Tensor( + [s if d else 0.0 for d in latent_dim] + ).to(self.device), + "preds": [], + } + for s in np.linspace(spaced_range[0], spaced_range[1], num_samples) + ] + + with torch.no_grad(): + if preprocess: + if self.hparams.crop_using_localise_model: + localise_path = self.hparams.crop_using_localise_model.format( + fold=self.hparams.fold + ) + img = crop_img_using_localise_model( + img, + localise_path, + spacing=self.hparams.spacing, + crop_to_grid_size=self.hparams.localise_voxel_grid_size, + ) + else: + img = preprocess_image( + img, + spacing=self.hparams.spacing, + crop_to_grid_size_xy=self.hparams.crop_to_grid_size, + intensity_scaling=self.hparams.intensity_scaling, + intensity_window=self.hparams.intensity_window, + ) + + + + img_arr = sitk.GetArrayFromImage(img) + + if context_map is not None: + context_map = resample_mask_to_image(img, context_map) + cmap_arr = sitk.GetArrayFromImage(img) + + if seg is not None: + seg = resample_mask_to_image(img, seg) + seg_arr = sitk.GetArrayFromImage(seg) + + if self.hparams.ndims == 2: + slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] + + if context_map is not None: + cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])] + + if seg is not None: + seg_slices = [seg_arr[z, :, :] for z in range(seg_arr.shape[0])] + else: + slices = [img_arr] + if context_map is not None: + cmap_slices = [cmap_arr] + + if seg is not None: + seg_slices = [seg_arr] + + for idx, i in enumerate(slices): + x = torch.Tensor(i).to(self.device) + x = x.unsqueeze(0) + x = x.unsqueeze(0) + + if context_map is not None: + c = torch.Tensor(cmap_slices[idx]).to(self.device) + c = c.unsqueeze(0) + c = c.unsqueeze(0) + + x = torch.cat((x, c), dim=1) + + if seg is not None: + s = torch.Tensor(seg_slices[idx]).to(self.device) + s = s.unsqueeze(0) + s = s.unsqueeze(0) + + # Add in background channel + not_s = 1 - 
s.max(axis=1).values + not_s = torch.unsqueeze(not_s, dim=1) + s = torch.cat((not_s, s), dim=1).float() + + if self.hparams.prob_type == "prob": + if seg is not None: + self.prob_unet.forward(x, seg=s) + else: + self.prob_unet.forward(x) + + for sample in samples: + if self.hparams.prob_type == "prob": + if sample["name"] == "mean": + if seg is None: + y = self.prob_unet.sample(testing=True, use_mean=True) + else: + y = self.prob_unet.reconstruct(use_posterior_mean=True) + else: + if seg is None: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + else: + y = self.prob_unet.reconstruct( + use_posterior_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + + # else: + # if sample["name"] == "mean": + # y = self.prob_unet.sample(x, mean=True) + # else: + # y = self.prob_unet.sample( + # x, + # mean=True, + # std_devs_from_mean=sample["std_dev_from_mean"], + # ) + + y = y.squeeze(0) + # y = np.argmax(y.cpu().detach().numpy(), axis=0) + y = torch.sigmoid(y) + sample["preds"].append(y.cpu().detach().numpy()) + + result = {} + for sample in samples: + pred_arr = sample["preds"][0] + + if self.hparams.ndims == 2: + pred_arr = np.expand_dims(pred_arr, 1) + if len(sample["preds"]) > 1: + pred_arr = np.stack(sample["preds"], axis=1) + + result[sample["name"]] = {} + + for idx, structure in enumerate(self.hparams.structures): + pred = sitk.GetImageFromArray(pred_arr[idx + 1]) # Skip the background + pred = pred > 0.5 # Threshold softmax at 0.5 + pred = sitk.Cast(pred, sitk.sitkUInt8) + + pred.CopyInformation(img) + pred = postprocess_mask(pred) + pred = sitk.Resample( + pred, img, sitk.Transform(), sitk.sitkNearestNeighbor + ) + + result[sample["name"]][structure] = pred + + return result + + def validate( + self, + img, + manual_observers, + samples, + mean, + matching_type="best", + window=[-0.5, 1.0], + ): + metrics = {"DSC": "max", "HD": "min", "ASD": "min"} + result = {} + + contour_cmaps = ["RdPu", "YlOrRd", "GnBu", "OrRd", "YlGn", "YlGnBu"] + structures = self.hparams.structures + + try: + cut = get_com(mean["mean"][structures[0]]) + except ValueError: + cut = [int(i / 2) for i in img.GetSize()][::-1] + + vis = ImageVisualiser(img, cut=cut, figure_size_in=16, window=window) + + mean_contours = {} + for idx, structure in enumerate(structures): + color_map = matplotlib.colormaps.get_cmap( + contour_cmaps[idx % len(structures)] + ) + mean_contours[f"mean_{structure}"] = mean["mean"][structure] + + vis.add_contour( + mean_contours, color=color_map(0.35), linewidth=3, show_legend=False + ) + + manual_color = color_map(0.9) + + manual_observers_struct = { + f"{man_struct}_{structure}": manual_observers[man_struct][structure] + for man_struct in manual_observers + } + + vis.add_contour( + manual_observers_struct, + color=manual_color, + linewidth=0.5, + show_legend=False, + ) + + intersection_mask = get_intersection_mask(manual_observers_struct) + union_mask = get_union_mask(manual_observers_struct) + + vis.add_contour( + intersection_mask, + name=f"intersection_{structure}", + color=manual_color, + linewidth=3, + ) + vis.add_contour( + union_mask, name=f"union_{structure}", color=manual_color, linewidth=3 + ) + + samples_struct = { + f"{sample_struct}_{structure}": samples[sample_struct][structure] + for sample_struct in samples + } + vis.add_contour( + samples_struct, + linewidth=1.5, + color={ + s: c + for s, c in zip( + samples_struct, + color_map(np.linspace(0.1, 0.7, len(samples_struct))), + ) + }, + 
)
+
+            sim = {
+                k: np.zeros((len(samples_struct), len(manual_observers_struct)))
+                for k in metrics
+            }
+            msim = {
+                k: np.zeros((len(samples_struct), len(manual_observers_struct)))
+                for k in metrics
+            }
+            for sid, samp in enumerate(samples_struct):
+                for oid, obs in enumerate(manual_observers_struct):
+                    sample_metrics = get_metrics(
+                        manual_observers_struct[obs], samples_struct[samp]
+                    )
+                    mean_metrics = get_metrics(
+                        manual_observers_struct[obs], mean_contours[f"mean_{structure}"]
+                    )
+
+                    for k in sample_metrics:
+                        sim[k][sid, oid] = sample_metrics[k]
+                        msim[k][sid, oid] = mean_metrics[k]
+
+            result[f"probnet_{structure}"] = {k: [] for k in metrics}
+            result[f"unet_{structure}"] = {k: [] for k in metrics}
+            for k in sim:
+                val = sim[k]
+                if matching_type == "hungarian":
+                    if metrics[k] == "max":
+                        val = -val
+                    row_idx, col_idx = linear_sum_assignment(val)
+                    prob_unet_mean = sim[k][row_idx, col_idx].mean()
+                else:
+                    if metrics[k] == "max":
+                        prob_unet_mean = val.max()
+                    else:
+                        prob_unet_mean = val.min()
+                result[f"probnet_{structure}"][k].append(prob_unet_mean)
+
+                val = msim[k]
+                if matching_type == "hungarian":
+                    if metrics[k] == "max":
+                        val = -val
+                    row_idx, col_idx = linear_sum_assignment(val)
+                    unet_mean = msim[k][row_idx, col_idx].mean()
+                else:
+                    if metrics[k] == "max":
+                        unet_mean = val.max()
+                    else:
+                        unet_mean = val.min()
+                result[f"unet_{structure}"][k].append(unet_mean)
+
+        fig = vis.show()
+
+        return result, fig
+
+    def training_step(self, batch, _):
+        x, c, y, m, _ = batch
+
+        # Add background layer for one-hot encoding
+        not_y = 1 - y.max(axis=1).values
+        not_y = torch.unsqueeze(not_y, dim=1)
+        y = torch.cat((not_y, y), dim=1).float()
+
+        # Concat context map to image if we have one
+        if c.numel() > 0:
+            x = torch.cat((x, c), dim=1).float()
+
+        if self.hparams.prob_type == "prob":
+            self.prob_unet.forward(x, y, training=True)
+            loss = self.prob_unet.loss(y, mask=m)
+
+        training_loss = loss["loss"]
+
+        # L2 regularisation of the posterior, prior and fcomb layers is
+        # handled using weight decay instead of an explicit penalty term
+        self.log(
+            "training_loss",
+            training_loss.detach(),
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+        )
+
+        self.kl_div = 1  # disabled: loss["kl_div"].detach().cpu()
+
+        for k in loss:
+            if k in ("loss", "kl_div"):
+                continue
+            self.log(
+                k,
+                loss[k].detach() if isinstance(loss[k], torch.Tensor) else loss[k],
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+            )
+        return training_loss
+
+    def validation_step(self, batch, _):
+        if self.validation_directory is None:
+            self.validation_directory = Path(tempfile.mkdtemp())
+
+        n = self.hparams.num_observers
+        m = self.hparams.num_observers
+
+        with torch.set_grad_enabled(False):
+            x, c, y, _, info = batch
+
+            # Save off slices/volumes so the entire structure can be analysed
+            # at the end of the validation epoch
+            for s in range(y.shape[0]):
+                img_file = 
self.validation_directory.joinpath(
+                    f"img_{info['case'][s]}_{info['z'][s]}.npy"
+                )
+                np.save(img_file, x[s].squeeze(0).cpu().numpy())
+
+                if c.numel() > 0:
+                    cmap_file = self.validation_directory.joinpath(
+                        f"cmap_{info['case'][s]}_{info['z'][s]}.npy"
+                    )
+                    np.save(cmap_file, c[s].squeeze(0).cpu().numpy())
+
+                mask_file = self.validation_directory.joinpath(
+                    f"mask_{info['case'][s]}_{info['z'][s]}_{info['observer'][s]}.npy"
+                )
+                np.save(mask_file, y[s].cpu().numpy())
+
+            # Image (and context map) will be same for all in batch
+            x = x[0].unsqueeze(0)
+            if c.numel() > 0:
+                c = c[0].unsqueeze(0)
+            if self.hparams.ndims == 2:
+                vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0]), axis="z")
+            else:
+                vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0, 0]))
+
+            if self.hparams.ndims == 2:
+                x = x.repeat(m, 1, 1, 1)
+
+                if c.numel() > 0:
+                    c = c.repeat(m, 1, 1, 1)
+            else:
+                x = x.repeat(m, 1, 1, 1, 1)
+
+                if c.numel() > 0:
+                    c = c.repeat(m, 1, 1, 1, 1)
+
+            if c.numel() > 0:
+                x = torch.cat((x, c), dim=1)
+
+            seg = None
+            if self.use_structure_context:
+                # Add background layer for one-hot encoding
+                not_y = 1 - y.max(axis=1).values
+                not_y = torch.unsqueeze(not_y, dim=1)
+                seg = torch.cat((not_y, y), dim=1).float()
+
+            self.prob_unet.forward(x, seg=seg)
+
+            py = self.prob_unet.sample(testing=True)
+            py = py.to("cpu")
+
+            pred_y = torch.zeros(py[:, 0, :].shape).int()
+            for b in range(py.shape[0]):
+                pred_y[b] = py[b, :].argmax(0).int()
+
+            y = y.squeeze(1)
+            y = y.int()
+            y = y.to("cpu")
+
+            # TODO Make this work for multi class
+            # Generalised Energy Distance with d(a, b) = 1 - IoU(a, b):
+            # GED^2 = 2 E[d(S, Y)] - E[d(S, S')] - E[d(Y, Y')]
+            # where S are model samples and Y are observer contours
+            jaccard = JaccardIndex(num_classes=2)
+            term_1 = 0
+            for i in range(n):
+                for j in range(m):
+                    if pred_y[i].sum() + y[j].sum() == 0:
+                        continue
+                    iou = jaccard(pred_y[i], y[j])
+                    term_1 += 1 - iou
+            term_1 = term_1 * (2 / (m * n))
+
+            term_2 = 0
+            for i in range(n):
+                for j in range(n):
+                    if pred_y[i].sum() + pred_y[j].sum() == 0:
+                        continue
+                    iou = jaccard(pred_y[i], pred_y[j])
+                    term_2 += 1 - iou
+            term_2 = term_2 * (1 / (n * n))
+
+            term_3 = 0
+            for i in range(m):
+                for j in range(m):
+                    if y[i].sum() + y[j].sum() == 0:
+                        continue
+                    iou = jaccard(y[i], y[j])
+                    term_3 += 1 - iou
+            term_3 = term_3 * (1 / (m * m))
+
+            D_ged = term_1 - term_2 - term_3
+
+            contours = {}
+            for o in range(n):
+                obs_y = y[o].float()
+                if self.hparams.ndims == 2:
+                    obs_y = obs_y.unsqueeze(0)
+                contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y)
+            for mm in range(m):
+                samp_pred = pred_y[mm].float()
+                if self.hparams.ndims == 2:
+                    samp_pred = samp_pred.unsqueeze(0)
+                contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred)
+
+            vis.add_contour(contours, colormap=matplotlib.colormaps.get_cmap("cool"))
+            vis.show()
+
+            figure_path = f"ged_{info['z'][s]}.png"
+            plt.savefig(figure_path, dpi=300)
+            plt.close("all")
+
+            try:
+                self.logger.experiment.log_image(figure_path)
+            except AttributeError:
+                # Likely offline mode
+                pass
+
+            self.log("GED", D_ged)
+
+        return info
+
+    def validation_epoch_end(self, validation_step_outputs):
+        cases = {}
+        for info in validation_step_outputs:
+            for case, z, observer in zip(info["case"], info["z"], info["observer"]):
+                if case not in cases:
+                    cases[case] = {"slices": z.item(), "observers": [observer]}
+                else:
+                    if z.item() > cases[case]["slices"]:
+                        cases[case]["slices"] = z.item()
+                    if observer not in cases[case]["observers"]:
+                        
cases[case]["observers"].append(observer)
+
+        metrics = ["DSC", "HD", "ASD"]
+        computed_metrics = {
+            **{
+                f"probnet_{s}_{m}": [] for m in metrics for s in self.hparams.structures
+            },
+            **{f"unet_{s}_{m}": [] for m in metrics for s in self.hparams.structures},
+        }
+
+        if len(cases) == 0:
+            return
+
+        prob_surface_dice = 0
+        prob_dice = 0
+
+        for case in cases:
+            img_arrs = []
+            cmap_arrs = []
+            cmap_arr = None
+            slices = []
+
+            if self.hparams.ndims == 2:
+                for z in range(cases[case]["slices"] + 1):
+                    img_file = self.validation_directory.joinpath(f"img_{case}_{z}.npy")
+                    if img_file.exists():
+                        img_arrs.append(np.load(img_file))
+                        slices.append(z)
+
+                    cmap_file = self.validation_directory.joinpath(
+                        f"cmap_{case}_{z}.npy"
+                    )
+                    if cmap_file.exists():
+                        cmap_arrs.append(np.load(cmap_file))
+
+                img_arr = np.stack(img_arrs)
+
+                if len(cmap_arrs) > 0:
+                    cmap_arr = np.stack(cmap_arrs)
+
+            else:
+                img_file = self.validation_directory.joinpath(f"img_{case}_0.npy")
+                img_arr = np.load(img_file)
+
+                cmap_file = self.validation_directory.joinpath(f"cmap_{case}_0.npy")
+                if cmap_file.exists():
+                    cmap_arr = np.load(cmap_file)
+
+            img = sitk.GetImageFromArray(img_arr)
+            img.SetSpacing(self.hparams.spacing)
+
+            observers = {}
+            for observer in cases[case]["observers"]:
+                if self.hparams.ndims == 2:
+                    mask_arrs = []
+                    for z in slices:
+                        mask_file = self.validation_directory.joinpath(
+                            f"mask_{case}_{z}_{observer}.npy"
+                        )
+
+                        mask_arrs.append(np.load(mask_file))
+
+                    mask_arr = np.stack(mask_arrs, axis=1)
+
+                else:
+                    # Volumes are saved with slice index 0 (matching img_file above)
+                    mask_file = self.validation_directory.joinpath(
+                        f"mask_{case}_0_{observer}.npy"
+                    )
+                    mask_arr = np.load(mask_file)
+
+                observers[f"manual_{observer}"] = {}
+                for idx, structure in enumerate(self.hparams.structures):
+                    mask = sitk.GetImageFromArray(mask_arr[idx])
+                    mask = sitk.Cast(mask, sitk.sitkUInt8)
+                    mask.CopyInformation(img)
+                    observers[f"manual_{observer}"][structure] = mask
+
+            context_map = None
+            if cmap_arr is not None:
+                context_map = sitk.GetImageFromArray(cmap_arr)
+                context_map.SetSpacing(self.hparams.spacing)
+
+            seg = None
+            if self.use_structure_context:
+                # TODO choose the observer to pass properly
+                seg = observers[f"manual_{observer}"][structure]
+
+            try:
+                mean = self.infer(
+                    img,
+                    context_map=context_map,
+                    seg=seg,
+                    num_samples=1,
+                    sample_strategy="mean",
+                    preprocess=False,
+                )
+                samples = self.infer(
+                    img,
+                    context_map=context_map,
+                    seg=seg,
+                    sample_strategy="spaced",
+                    num_samples=5,
+                    spaced_range=[-2, 2],
+                    preprocess=False,
+                )
+            except Exception as e:
+                print(f"ERROR DURING VALIDATION INFERENCE: {e}")
+                return
+
+            result, fig = self.validate(
+                img, observers, samples, mean, matching_type="best"
+            )
+
+            figure_path = f"valid_{case}.png"
+            fig.savefig(figure_path, dpi=300)
+            plt.close("all")
+
+            try:
+                self.logger.experiment.log_image(figure_path)
+            except AttributeError:
+                # Likely offline mode
+                pass
+
+            for t in result:
+                for m in metrics:
+                    computed_metrics[f"{t}_{m}"] += result[t][m]
+
+            # Compute the probabilistic (surface) dice
+            for idx, structure in enumerate(self.hparams.structures):
+                gt_labels = []
+                for observer in cases[case]["observers"]:
+                    gt_labels.append(observers[f"manual_{observer}"][structure])
+
+                sample_labels = []
+                for rk in samples:
+                    sample_labels.append(samples[rk][structure])
+
+                prob_dice += probabilistic_dice(
+                    gt_labels, sample_labels, dsc_type="dsc"
+                )
+                prob_surface_dice += 
probabilistic_dice( + gt_labels, sample_labels, dsc_type="sdsc", tau=3 + ) + + prob_dice = prob_dice / len(cases) + if np.isnan(prob_dice): + prob_dice = 0 + self.log( + "probabilisticDice", + prob_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + prob_surface_dice = prob_surface_dice / len(cases) + if np.isnan(prob_surface_dice): + prob_surface_dice = 0 + self.log( + "probabilisticSurfaceDice", + prob_surface_dice, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + if self.kl_div: + p = u = 0 + for s in self.hparams.structures: + p += np.array(computed_metrics[f"probnet_{s}_DSC"]).mean() + u += np.array(computed_metrics[f"unet_{s}_DSC"]).mean() + + p /= len(self.hparams.structures) + u /= len(self.hparams.structures) + computed_metrics["scaled_DSC"] = ((p + u) / 2) + (p - u) - self.kl_div + + for cm in computed_metrics: + self.log( + cm, + np.array(computed_metrics[cm]).mean(), + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + ) + + # shutil.rmtree(self.validation_directory) + + +def main(args, config_json_path=None): + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + # args.default_root_dir = str(args.working_dir) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) + args.accumulate_grad_batches = {0: 5, 5: 10, 10: 15} + # args.precision = 16 + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + + data_module = UNetDataModule(**dict_args) + + prob_unet = ProbUNet(**dict_args) + + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + + lr_monitor = LearningRateMonitor(logging_interval="step") + trainer.callbacks.append(lr_monitor) + + # Save the best model + if args.checkpoint_var: + checkpoint_callback = ModelCheckpoint( + monitor=args.checkpoint_var, + dirpath=args.default_root_dir, + filename="probunet-{epoch:02d}-{" + args.checkpoint_var + ":.2f}", + save_top_k=1, + mode=args.checkpoint_mode, + ) + trainer.callbacks.append(checkpoint_callback) + + if args.early_stopping_var: + early_stop_callback = GECOEarlyStopping( + monitor=args.early_stopping_var, + min_delta=args.early_stopping_min_delta, + patience=args.early_stopping_patience, + verbose=True, + mode=args.early_stopping_mode, + ) + trainer.callbacks.append(early_stop_callback) + + trainer.fit(prob_unet, data_module) + + +def parse_config_file(config_json_path, args): + with open(config_json_path, "r") as f: + params = json.load(f) + for key in params: + 
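# Each JSON key becomes a "--key" command line flag; list values are
+        # appended one at a time so options declared with nargs="+" parse correctly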
+        args.append(f"--{key}")
+
+        if isinstance(params[key], list):
+            for list_val in params[key]:
+                args.append(str(list_val))
+        else:
+            args.append(str(params[key]))
+
+    return args
+
+
+if __name__ == "__main__":
+    args = None
+    config_json_path = None
+    if len(sys.argv) == 2:
+        # Check if a JSON file was passed; if so read the arguments from there...
+        if sys.argv[-1].endswith(".json"):
+            config_json_path = sys.argv[-1]
+            args = parse_config_file(config_json_path, [])
+
+    arg_parser = ArgumentParser()
+    arg_parser = ProbUNet.add_model_specific_args(arg_parser)
+    arg_parser = UNetDataModule.add_model_specific_args(arg_parser)
+    arg_parser = pl.Trainer.add_argparse_args(arg_parser)
+    arg_parser.add_argument(
+        "--config", type=str, default=None, help="JSON file with parameters to load"
+    )
+    arg_parser.add_argument(
+        "--seed", type=int, default=42, help="an integer to use as seed"
+    )
+    arg_parser.add_argument(
+        "--experiment", type=str, default="default", help="Name of experiment"
+    )
+    arg_parser.add_argument("--working_dir", type=str, default="./working")
+    arg_parser.add_argument("--num_observers", type=int, default=5)
+    arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1])
+    arg_parser.add_argument("--offline", type=bool, default=False)
+    arg_parser.add_argument("--comet_api_key", type=str, default=None)
+    arg_parser.add_argument("--comet_workspace", type=str, default=None)
+    arg_parser.add_argument("--comet_project", type=str, default=None)
+    arg_parser.add_argument("--resume_from", type=str, default=None)
+    arg_parser.add_argument("--early_stopping_var", type=str, default=None)
+    arg_parser.add_argument("--early_stopping_min_delta", type=float, default=0.01)
+    arg_parser.add_argument("--early_stopping_patience", type=int, default=50)
+    arg_parser.add_argument("--early_stopping_mode", type=str, default="max")
+    arg_parser.add_argument("--checkpoint_var", type=str, default=None)
+    arg_parser.add_argument("--checkpoint_mode", type=str, default="max")
+
+    parsed_args = arg_parser.parse_args(args)
+
+    # Check if the config arg was passed; if so take over its values and reparse
+    if parsed_args.config:
+        config_json_path = parsed_args.config
+        args = parse_config_file(config_json_path, sys.argv[1:])
+        parsed_args = arg_parser.parse_args(args)
+
+    main(parsed_args, config_json_path=config_json_path)
diff --git a/platipy/imaging/cnn/train_lidc.py b/platipy/imaging/cnn/train_lidc.py
new file mode 100644
index 00000000..7a8132ba
--- /dev/null
+++ b/platipy/imaging/cnn/train_lidc.py
@@ -0,0 +1,1451 @@
+# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
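+
+"""Train a probabilistic UNet on the LIDC (lung nodule) dataset.
+
+During validation, the Generalised Energy Distance (GED) between segmentations
+sampled from the model and the manual observer contours is computed and logged,
+and the best model is checkpointed on that metric.
+"""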
+
+import sys
+import os
+import tempfile
+import json
+import logging
+import random
+import math
+
+from pathlib import Path
+from argparse import ArgumentParser
+
+import SimpleITK as sitk
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+import comet_ml  # pylint: disable=unused-import
+from pytorch_lightning.loggers import CometLogger
+from torchmetrics import JaccardIndex
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+from platipy.imaging.cnn.prob_unet import ProbabilisticUnet
+from platipy.imaging.cnn.hierarchical_prob_unet import HierarchicalProbabilisticUnet
+from platipy.imaging.cnn.unet import l2_regularisation
+from platipy.imaging.cnn.dataload import UNetDataModule
+from platipy.imaging.cnn.dataset import crop_img_using_localise_model
+from platipy.imaging.cnn.utils import preprocess_image, postprocess_mask, get_metrics
+from platipy.imaging.cnn.lidc_dataset import LIDCDataset
+from platipy.imaging.cnn.sampler import ObserverSampler
+
+from platipy.imaging import ImageVisualiser
+from platipy.imaging.label.utils import get_com, get_union_mask, get_intersection_mask
+
+logger = logging.getLogger(__name__)
+
+LIDC_PAT_IDS = [
+    '100036212881370097961774473021',
+    '100063870746088919758706456900',
+    '100196544861967692255437311943',
+    '100623635004206394377317082803',
+    '101324598070011890446155612648',
+    '101370605276577556143013894866',
+    '101493103577576219860121359500',
+    '101860660338020663891435526309',
+    '102630309685058068986252460648',
+    '102727149680558741079248783846',
+    '103314508209751760357909975504',
+    '103418693958392288272205189260',
+    '103909846248671377349144781232',
+    '103936837659251066416686642502',
+    '104132635304747270459442606987',
+    '104261792177613018096518227983',
+    '105788018052982730868269746567',
+    '106068656437895419353734796481',
+    '106078428010894478241933303675',
+    '106200285452190102759452314796',
+    '106510194697877936244358836587',
+    '106941897938005462856300775001',
+    '107063920355575867343721608071',
+    '107117680739311641559517949056',
+    '107410301078591208486461903808',
+    '107508721613473648848595116695',
+    '107781084804242272261405060893',
+    '108196413521053893318500192469',
+    '108532092918956660012164380891',
+    '108965194825990605371103747168',
+    '109276391800556245215148059611',
+    '109666061122737469139155049085',
+    '109666408795826210278851151448',
+    '110552610218214013901193645855',
+    '110583593061096148970283270811',
+    '110814261083396253384916416499',
+    '111191540532827735098734161238',
+    '111578894836113526186151044207',
+    '111854046978561013667340276013',
+    '112366982107658327153398402250',
+    '112657410053794387237858900758',
+    '113336731083882049431478817978',
+    '113559330012918930951543504919',
+    '113703017711510899918035270531',
+    '113707792304333965399737722872',
+    '114128037878340599989460011716',
+    '114825706723542293203506302449',
+    '115235596345340852495791642145',
+    '115862643355747092400544598419',
+    '116398718009259533261076124801',
+    '117534698898058162410044712500',
+    '117743534419290251609867537261',
+    '117809518219141943884237457410',
+    '118873832573622894386164980989',
+    '119160670718020063800771208605',
+    '119329123269319944171115250782',
+    '120580422725898456426920474122',
+    '120781932093240979642841947845',
+    '120891719798974864125658202102',
+    '121477207174726219407790693667',
+    
'123305427042753000123559340989', + '123740919388465597562183738584', + '123757952382514314958340753104', + '123906098414077028808869184162', + '124043676181995670004300402875', + '124508094927772723365648607398', + '124918015490899130716961145499', + '125091393024330590664846028519', + '125196609469610531868822891452', + '125755525194735914218741665527', + '125968645555307955291497942287', + '126579974603592904802118981222', + '126627335783756807832314347589', + '127299727722058129007613721740', + '127335250316932247803669656678', + '127418142282097214960207117703', + '128173929531863533946053948407', + '128331462267861650186674238297', + '128546361353633775045659990592', + '128707571431497524813677196205', + '128882236014679059490275573019', + '129172271455973626256608772850', + '129889229961080005326707611943', + '129896477022915871974945777250', + '130212238381185558806293057208', + '130369694349415652282198282842', + '131383203689189807643685075952', + '131386286964058684099078579332', + '131639774933066433859447191174', + '132644395812981309476663139204', + '133142523452166818708847170969', + '134300369034461601145053122049', + '135015991565946540622375616104', + '135045503670480596826162811572', + '135295547724563274783029730811', + '135354380632609557236443907760', + '135441977655501389639333579222', + '135649280843401063699693855484', + '135664533107982155600588797111', + '136876570107341293094375256847', + '136980005417322283801813327480', + '137094437164923904545127071746', + '137641101085876511008939491396', + '137750025318220725626773657356', + '138141860630920227890886912798', + '138500869556929311152852869529', + '138662618939201961985267634802', + '138751917514248782010778675083', + '139007379520761595657549251638', + '139110171863263699469684411692', + '139195294068636036088709372879', + '139230809749319928621494922594', + '139267808951641148096439810470', + '140032007657380601732641072716', + '140071422310372235469229806680', + '140459664666732545368292166616', + '141598007834192132861466794478', + '141781718363585577687281036482', + '142292437379762052008677682241', + '142440312470956325774436777849', + '142946965450964505637435479607', + '143147006327116114851951568235', + '143774983852765282237869625332', + '143992300791476122502099615811', + '144067942028925306269987156361', + '144098736774350495825430776051', + '144115995771389118102515965956', + '145101705103753228347955694208', + '145271800733926000322166019837', + '145309426008342216455171944681', + '145373944605191222309393681361', + '145622720063142155634645333747', + '145643207559809787753499447041', + '145676616748203729170191829666', + '145713291593525654962776610448', + '146389062541391302265553091834', + '146673599499822520893810766696', + '146693794897525614810620472732', + '146874505599848236392134956519', + '147007934469588183141400115665', + '147078385001386003318643105546', + '147369388209402780868232397204', + '147379540710364690031394735651', + '147528196931888636188830549231', + '147704659627178252268482840005', + '147733809306552412615241520287', + '148358074732977192643732679421', + '149109655051916227035843294920', + '149666752178793201786367154542', + '150115359973664720902580965129', + '150164970015823494629986891791', + '150188222462440903489565305064', + '150264634200093580367988090366', + '150506552813421138982733853172', + '150518738304733101418007977406', + '150842260481112310322112796796', + '151462193009815033155890307894', + '153537389944109506778485451374', + 
'153879314331255637907074508526', + '154013297600241321274325492714', + '154309317539716408682560573374', + '154418190532000744482160157724', + '154819036420703297447036427861', + '154900488526447142456824787975', + '155364374491657276710118621544', + '155422883755338970256390193392', + '155557195003938299459901738073', + '155894685355411131558273211903', + '156049232240708034287535871530', + '156618458422978822822335971869', + '156996082231474939899578263697', + '157442654089163434314265938550', + '157659992364948575945704799673', + '157894141428308694229687552009', + '158361076621595636405240935776', + '158967312129487733823739945169', + '159154526664736300420420966290', + '159596875781736451820677432740', + '160151929696133590075333292104', + '160177241792963961453620813854', + '160725905583575318613483772884', + '160881090830720390023167668360', + '161049972809841608349356015239', + '161084022770774658371336888233', + '163285696979961262813032969899', + '163364452480273959392764002225', + '163747591665714002593484708741', + '164156833754689412076055510609', + '164204559177081679885683922807', + '164217405113124054884377740544', + '165017330457520545748655969705', + '165089822985195173225904535454', + '165584568388616938096482975079', + '165741627456266454825445080793', + '165928214371593461789742284562', + '166043246571762775528650569121', + '166045097359941678759735162107', + '166516471529644578497840692226', + '167008253410339916258991990447', + '167332257805620834964807748714', + '167385525980664447633063747843', + '167461514102912667251280826923', + '167481427179213606474786125326', + '167583044843097027161721296977', + '168697575496383193075236640012', + '168750903689425205479060653325', + '170453435914624225750000821602', + '172057344391312701982079788940', + '172781764368509306119706213242', + '172879081603366631101823390088', + '172891325013037275999502688439', + '173046050032662152275088620852', + '173476782913985657859750913739', + '173480979711457247360986415860', + '174771861434175096253509271253', + '174809695196160918760894715771', + '175760080411828194592094064838', + '176119253416372337672619835088', + '176189447010290638749878151482', + '176724808324080338407584179625', + '178008579082565292854358918057', + '178054575127307086914328123377', + '178055839088630125000227320316', + '178854189219941544047372620516', + '179835446111622550556046150972', + '179926192573878616596153758496', + '180089742036563526934101223693', + '180451931053147153984482896172', + '180685600263085599637486877417', + '180841215015376467807831976209', + '181183885950568332347816930781', + '181398349112343043834924846577', + '181417885049753012096413476972', + '181593981812293966599842811096', + '181785797305004428693015149455', + '182110250999021699484937823450', + '182190805623310236601030915541', + '182529561822199780545440408033', + '183088557662694130014178072362', + '183948389971742539190611659002', + '184902022874872369505485150214', + '185466477130333115741807427373', + '185496461202477950899792771158', + '185621295429000591339779532687', + '185810436275168701789786930141', + '185852472264155752604841654643', + '186165274648953680758061039841', + '186309477398566417518475261664', + '186378343509510052901886052675', + '186545761857837015081715243889', + '187400292293901990471724598227', + '188204058669364993757306038569', + '189280790286095758626785941529', + '189670740451073875347541923169', + '189690798782956693829000203312', + '190036064555105762519342089670', + 
'190045142764457963256735097041', + '190188259083742759886805142125', + '190578678221433604759795450204', + '190722769298607626547238572819', + '190773468090009472439928213592', + '190838841811080874400532510738', + '191073955214402396197702655759', + '191294111388782437161651532011', + '191417614594472789714997861946', + '191425307197546732281885591780', + '192544770425703459348950670623', + '192711002126941665635416825711', + '193183513859852492693631043956', + '193241055656414949090207821605', + '193262972800234484744399468745', + '193438875802960011844317570223', + '193632643404878480165927810858', + '194041219679835518833081477370', + '194069107440754900812783343110', + '194264995369749234588783691976', + '194408227371077105112329346518', + '194592222324100678493470529736', + '194890498394311571938776742465', + '194989275190883931437102127157', + '195049537858207829551926654069', + '195418537319448622946575911056', + '195536973222079080038674523383', + '195975724868929317649402600442', + '196064403160663787222867039271', + '196095989034530030261387979708', + '196764647858692982718767974745', + '198021313421875017435046583973', + '198480745206106280149820227940', + '198938890575750893735876189224', + '199795594928180609213453658461', + '200080981983036698023153287807', + '200388868639180278429905174768', + '200556866344536431817884531386', + '201356741575321997579793754679', + '201981314170126152586920703158', + '202012743926377385386645171009', + '202087967488314629195929644257', + '202128825827009854792580489975', + '202313610740309364769212928365', + '202796045319384029259715575526', + '202827697628802454715911097315', + '202972199430488069905719748096', + '203021212653448281775262541232', + '203511118110447056499830468108', + '203745372924354240670222118382', + '203918583798186026281890202047', + '204167514516464519480528962273', + '204769919102665056662271574089', + '204876032380829136260582432402', + '205265707007087091421912273273', + '205295494158831463556858756520', + '207817820000493988193034888372', + '208001565962486054565606721023', + '208044717702980576765085373999', + '208177797605474151106520124306', + '208253527688434910178787865281', + '208376669554572460085205852204', + '208962973581011466719041210639', + '209628614952351211282586434983', + '210105060472758363785916556991', + '210438744022591353484036052860', + '210460865401632134346677582159', + '211028058202430803930155553655', + '211053720209798423692283723094', + '211186311281767598090156083844', + '211592828197214585826096300114', + '211678183477028942594541249344', + '212173791575971142588528365479', + '212292046142156223429795319169', + '212341120080087350703610584139', + '212697393127299815450339637649', + '212995606100714447928555271551', + '213021056957630403662329810457', + '213021675581421639588001132423', + '213637791937214112275825554647', + '213690221459882536855619399918', + '213735911041216674787295333807', + '213747445868893344159976127183', + '214011375577453036409274126845', + '214246516675840948112237158320', + '214565862082106444257760478581', + '214802309639542154012576451584', + '215072449579869595247908643994', + '215530681893783493990923094397', + '215559453287121684893831546044', + '215789618966985272514045108451', + '216284625800598617220647330177', + '216376032357092323639269932442', + '216596541995300108433152752372', + '217723017594466509330246364292', + '217795632219483181846615097960', + '218334290952989306683953951600', + '218353171044902873293532449170', + 
'218441204865582835284481030822', + '218658642102832118810712329678', + '219081768759399413645906161308', + '219248583669253502298142724766', + '219603183011909533511019563626', + '220214273010365852692388401870', + '220250766483775468847898872976', + '220318312497776009932500036468', + '220440176952423273988539874706', + '220989155161038735570528881205', + '221121159879001927899714055324', + '221138977753587043062153841454', + '221467562933372960599414594591', + '222363923835851360513495795866', + '222444895275634974977432374937', + '222626438572167217610357947558', + '222993894345836169253703951249', + '223586178787100112140685204730', + '223728112116674740171130784710', + '223929343925645778515146916755', + '224816556084824998460890400016', + '224985459390356936417021464571', + '225213110794629789874295007045', + '225325726732923728196849027710', + '225336514441526266989689082116', + '225520892751936930714518753705', + '225862377550650653917701158715', + '225952820864846070881460157728', + '226088449233409636933226805676', + '226402817998131997261275093738', + '226719444846209417020566423366', + '227040504303112430978662427134', + '227063200244152860572015411033', + '228681789686851016136117645559', + '229037285595999197108233114060', + '229388260375062685257395968125', + '229453814067948782185747812606', + '229800514909823088973473373045', + '229967962418363878052269915938', + '230350651484561219081974325611', + '230800355892152356700875480882', + '230901123329037029807195618747', + '231462296937187240061810311146', + '232333503337684976076888189581', + '232484282311807761778575987163', + '232772963487864165109888489609', + '233185200895881555317060584785', + '233265558039230075858412321938', + '233360976397162618015897824056', + '233972273372473359632172336060', + '234120235997503955698755550667', + '234289514191030145998276287188', + '234418158081986309761071794125', + '234565379252678413221170803425', + '234994879536134968838889803884', + '235161199041083161581467190209', + '235632691141320250472721126959', + '235969059582659051462660101595', + '236473827829552232243074977782', + '236694495816915149507620783797', + '236722533610330821341378792527', + '239129180722773562585191194111', + '240259014486900472744222753388', + '240378069135678893749859521194', + '241438921576606444541425654823', + '241689052849438260084754220656', + '241863405068914130326466911331', + '242293704887164773216756108679', + '242682647435854260754234965628', + '243200997268477396967173500055', + '244389108887315437307204657105', + '244560540649171231768816505483', + '245332552721136420493204351294', + '246309121384607847712377186777', + '246541809203697037171004327898', + '246917032284352184215612201998', + '247360966390848191713625967090', + '247647652518040926954657385575', + '247813601128055334769770235073', + '248466778281073632010340585979', + '248536440095859239568984753729', + '248655117969906530264763921998', + '248983002956469734140739709749', + '249778071616205114955151949358', + '250118712805848631355370259962', + '250770014904528873190814943829', + '250807847445096579567542381073', + '252510005003229547978208913780', + '252807379241747927058486908845', + '253018306736883838620886843861', + '253123637959400428278268237946', + '253735246078920128978742275635', + '254284452408440744726985192393', + '254597093814481056655048098352', + '254734404186354295325724812720', + '255547294378103412654877543117', + '255763746952382642554143442493', + '255996783379921037765414923334', + 
'256230432050930445318439765836', + '256415982022395230813284649190', + '257497390857480311387964520553', + '259516711483314195853998851316', + '259564082580389615261241066925', + '260157243922037600449005570857', + '260650317266630055112768212364', + '261049959740626621730647171520', + '261859010096419762853634942928', + '261881453527700554965645624212', + '261896829953142370809123984374', + '262233919583075373552810415567', + '262372993180687488521371393069', + '262595324886254011362748468025', + '262752628257670445166772606344', + '262940731886389902640730271210', + '263031755598034095800749048602', + '263062819617890562693932243280', + '264122020925099449244490627864', + '264393460861616413192668347002', + '265116643168672454267341689554', + '265390640843728236501931252310', + '265657827498396379702226874113', + '265704884949271879044145982159', + '267013157670921984098319605661', + '267153963553416618872924015484', + '267214856456387865154306936080', + '267354935124724968585540389507', + '267430321455341154577885873460', + '267495169884268604035801498197', + '267607850463096817515575285607', + '267745911637906357699108280139', + '268186928222507173063266624020', + '268490252284941048496256182524', + '269532656350104249022826895294', + '269549495557169661890212203300', + '269815821605052946328618031845', + '269884215256755446329937557120', + '270093416536223137275197282336', + '270449468111534490183179612522', + '271910441918529291689264844963', + '271916801059080642953570118476', + '272102141116238598947111407982', + '272416900158314679872946504460', + '272747692608963634087245721224', + '272800023895682200488000769712', + '273398739486305201304583240983', + '273873499396396267864941267496', + '274533538808543280646156261676', + '274709137328120797052487052456', + '275166955484270759335764642842', + '275342292592333052340429698533', + '275399413024322303306387851279', + '275597169013178714160720614233', + '275892948066257709459536915546', + '276145303905683401700212326794', + '276357014804464864262939025676', + '276910031320955223579056176086', + '277707646521145217625235879315', + '277917899093238118799634234844', + '278314508776467367928735781601', + '278388415999131655907385534718', + '278420383577925795821592801973', + '278571116570364233687639963534', + '279133908903915300821612602665', + '279133997079513544257388158962', + '279815994089337890330418719400', + '279896786227805041031596230124', + '279957551156089498592959292779', + '279988868485323680600035963215', + '280050083021766496264016832213', + '280315210397549164238230581781', + '280531413986295071283803322793', + '280944743442493595294591879190', + '280963839031344532519735389631', + '281149709266012657212003921035', + '281491059962512876018428810934', + '281499745765120562304307889347', + '282376078213521832826757169162', + '282592678243385281753684471720', + '283065772077726118787509535476', + '283314504686117114905176791940', + '283363009241515390315316302638', + '283388878260039881020061501616', + '283747814329309197645868691560', + '284028192986684647771628933377', + '284088385995679184706750192135', + '284091129290542528109405627934', + '286257039861864871390665013995', + '286786177811953810051065476115', + '287560874054243719452635194040', + '288142673190348607009892502290', + '288625283139929827339044364850', + '288836300158326561947306862905', + '289084067435838460184299744241', + '289145216356920254799922475037', + '289514702605693988808674635347', + '289759534306508848209269388603', + 
'289886930576192132696568423166', + '290602873613899406112457532639', + '291105837361929821655470189849', + '291351987478858068691768357060', + '291817048149111149961802063949', + '291896014000911872483739610441', + '292207153428070172322351824035', + '292226263060859572730824828126', + '292267299840811335940503268889', + '292294220516976158142911380662', + '292417350644573735869959859895', + '292628672046109312619048073568', + '293433201388001070931182667378', + '293760350481778481843102622378', + '294168619890384187906139486190', + '294653264607473621174579354155', + '295517876835894730236647974030', + '297192539691487434404158853083', + '297491710261529399075427138612', + '297522203193490519910431870739', + '297927758886508342527732573281', + '298180121619674780553933016732', + '298731156484751583249477309475', + '298806137288633453246975630178', + '299122687641741853427119259207', + '299257181285137013436464151874', + '299799877133044736642536495362', + '300409677036863686976674785893', + '300568323537528705778699437287', + '300829918445389512656506955074', + '301465695265899538081208550111', + '301771434938229546704232023262', + '301893285809694674511225349300', + '302256916113808751290358737152', + '302506223349239046044276337140', + '302617159747584825959135390069', + '303099231937480740934110243375', + '303241414168367763244410429787', + '303407883137142435506738687070', + '303408504856716615682692778690', + '303856284614793157222299489373', + '303921580531000844896433537490', + '303924616208074142487120966776', + '304088547901303960997044129270', + '304128927772479718113589870111', + '305011216002720504189976406136', + '305196230708291505432654379176', + '305234403531996676303325051420', + '305703349227966644337102413678', + '305756441152306462678795884038', + '305863253247137744276642948253', + '305942859345281883515236698139', + '306209107793490820623279780488', + '306223432026099623241869333271', + '307585785421640145685948300552', + '308816704145697573874568073799', + '309118579255866880117522688599', + '309564852246680613688744241427', + '310090830439147530838142668838', + '310500752247645337767427178242', + '310641973458012698326158285915', + '310869628943317589181434675447', + '310943546307547399461310181653', + '311286315716340154235274558813', + '311453461766047239391876458979', + '311543823141683719978679431015', + '311849240521371267537044867782', + '311959464221077415127494283158', + '313434017581738347830026996031', + '313935464622070730633776692194', + '314071149857693870088107222867', + '314119990938826728339463484244', + '314138616411061948052843767346', + '314166198948235683770741486536', + '314311311148468433820105763522', + '315896257598312492181693323227', + '316049072680605749900519988649', + '316334611043225263064716257742', + '316733915710399203923795639422', + '317269195393986070164185617972', + '317273043933560935570616831007', + '317889411189319524523491450294', + '318162998398046037585497173804', + '318267749315379295095253644829', + '318387528257740161829281579408', + '319426861743264377998637994320', + '319437016886652687834302851680', + '319457039993091948420654822534', + '320868967976297780563487478878', + '321085339464682432111441689315', + '322060710223427694979493449810', + '322126192251489550021873181090', + '322524653283033873625769278172', + '322604336063659838243542603396', + '322613172334101771105940174946', + '322995703216827528672839235412', + '323012452641612668480978720171', + '323769537371782501749772855211', + 
'324291021807344718518738544991', + '324680252006411183918098592500', + '324827271082263044582582668595', + '324984193116544130562864675468', + '325608968275343627520297346991', + '325830590463177737667102135732', + '326240009912333401253143950265', + '326258759789625717227327619747', + '326975664298894323901926187239', + '327389263896737471496228866506', + '327833349640691088534060439127', + '328557408996437290097300166360', + '328985441791613247896558648523', + '329219900591912979945974892373', + '329360542712571362247573112426', + '329498168249206597932161726910', + '329613415894061309409362825998', + '330311286390409730062500041756', + '330694158626612266669807704245', + '330786107789886395527713704821', + '331662654015358587276208254750', + '331912290213688157146431129496', + '332442253024965725100180832377', + '333224958421615824054029320306', + '333362756208643390458434024958', + '333911857799694588686208472307', + '334142682219743556615674397992', + '334404540722450737157825443529', + '335511385175207746688439743644', + '335945989892274983962989644650', + '336137933660116977458622909107', + '336271942450113106996588030279', + '337712862699308652663458299448', + '338336992240867140131763172276', + '338445093314190310902086699869', + '338681063148697211042842329869', + '339170810277323131167631068432', + '339555666338359721401666777459', + '339975625902908481435949410827', + '349215784927595126067491535375', + '349904517868103143773797132680', + '359494088606212053886005767834', + '361004729625204263297457202086', + '370700630609225608130630902041', + '371593558888147552115426364555', + '373342454335320815569534817014', + '377835929784075736831041456357', + '378894347946872778030844834278', + '382216295557686059503096506837', + '385976057629453563236462196337', + '389796291341364059012825103262', + '392974922683023240281753530297', + '402240049299350560004923763412', + '403432615973617823449253024665', + '407122010994373607607380330242', + '419136021853948294349111583097', + '426292960722746490523326036933', + '427346572209935588501414358601', + '428749210060996400535779285747', + '429699865856208366565779600254', + '434247094534802729512368079584', + '434847191991072856104231378329', + '437063545936574671385154679589', + '440475654600005860316686606177', + '442836465012574652961690947717', + '443753256239015192514766534461', + '446141203915389353091048690870', + '456657617221039306684558366650', + '461502489756528153281751961352', + '467953176348008114522042094789', + '471106432014627176550866096857', + '473852308535992451018178219915', + '475498514082515864100304794963', + '477738775879476870029318616682', + '481620140149228611720235499832', + '490157381160200744295382098329', + '495416793311762759653372975706', + '498311492348035259054610823788', + '503842574668180938785276053000', + '508213936241272779545636798976', + '514061889253028999256401521179', + '518232152391739792448041571173', + '520374343116562553369175888897', + '520765263344096922537272606423', + '521368364741046707593021144453', + '522940913934103637427459195079', + '530012655070930408996523309860', + '531322615015231731672909400685', + '534939840290931669471335109271', + '541564533926823186990087847268', + '542174473570958426088104535645', + '547848375366758290042983535032', + '553558667874341556649497579315', + '554408346761971894109361624151', + '554779663962875550258283150632', + '558678722336494515704990369281', + '559695947493996195628793160329', + '559812581756025353887436599032', + 
'560355181698089203099323359955', + '561164700626233888666112673613', + '562420213730031312300572553711', + '563554930011057425448972918752', + '567798120866426088049494542765', + '568590496455752271550307195037', + '573428694448853086196304337865', + '577239032421661596436882763701', + '579308169960338143353751288115', + '584233139051825667176600857752', + '593684538737645231443136404239', + '598304550863987021597036432704', + '605832828213886887633700370869', + '607604221900107393271004692863', + '612557072934890742784673621068', + '614525372854817919294823651430', + '617121763259379501895075067753', + '618131472311494996047749006675', + '618911015325871816381733899879', + '622768579679891665663217992331', + '626070671580164939197414981549', + '633670303116637733124704376946', + '635629010879455042713206694910', + '636909670997007752324450926079', + '637798773907788996937534925551', + '639539378810400142491972900031', + '651527937137773549514457517157', + '657015929286431430594491075660', + '660741147026674947967595590708', + '662868666186676193600672178489', + '663171561625683396375993994291', + '669072071991651987864331187685', + '670416970629038619301651972112', + '670760254175238186822024748574', + '676549258486738448212921834668', + '679317079042102339180690634322', + '684533679602381775258237049962', + '684596493379395503904158873210', + '690051876029508763760267564639', + '691546855561342606351412524861', + '694111741408084165729604375833', + '694589790483267758352592215517', + '701448858177597192177075232452', + '704695217041193188031506800434', + '708247959534230054025514993152', + '709632041446581971088663252665', + '709632090821449989953075380168', + '709924281509506265799179464374', + '713948008609685722859788330157', + '716587475529555681439302477222', + '719888539703313386206966716806', + '721161755719965069770622927051', + '726410718508119500341020689167', + '728284743932342406301468721019', + '731392422930133359505662856851', + '732422099119935556160069769037', + '733596311057474916459070445858', + '742963139611542313018036607058', + '744833949134301384249366916145', + '748854471032735053174016711473', + '751341984016651692739143511794', + '751804495161511023323654179121', + '752603750437458841409277737148', + '760852670133552234806401724338', + '765927507488089285529132136179', + '765963830187381205722436255156', + '774224171618605001583693363297', + '774862199189105930723757496082', + '777672393907254490640128410455', + '780121538270617222115253441513', + '783812685850339917048882413972', + '788972240715000723677133060452', + '793828050288475038053022222683', + '803342036141517365362713834879', + '807635771657676107005923650811', + '807847494300907496059830326562', + '808134164613518601297092139901', + '815399168774050638734383723372', + '818004751275342069818453790263', + '823962363630113925122585468962', + '826821417415213438164503950143', + '837252783245693412667023877670', + '837260510399523070637863582641', + '837471965665975230389385298092', + '843094005962961412710063697075', + '847046315234934266835024451770', + '848745302315101424036793123771', + '849069697860879761549990488101', + '849090059968404135199482374496', + '849666522017598388280673804221', + '850568285181410483219385441815', + '858300670101158401140123918818', + '859733375254493048647590696534', + '867436015578673140421767840022', + '871736676471446897375109464440', + '879950425588844557177426831791', + '880369233375230345237001511006', + '886180838786633773936677813818', + 
'886430358468567310311007723004',
+    '887916493746193939407481623391',
+    '888021904600511420323095129935',
+    '888517498954149177086283916722',
+    '891182989185983545761655978623',
+    '897705953598294772269569489281',
+    '899082900417573006084750602123',
+    '899353684702041035700102438716',
+    '901103510796290218917651903823',
+    '902763751470786794946912471631',
+    '906824386019316813789030483695',
+    '908741193082513651836950434578',
+    '911801947447849749468305764840',
+    '915736860556289509455669530049',
+    '915986308688735366393353350740',
+    '919002813906622793381125049416',
+    '921287276013810837359841315530',
+    '924939006160714533549353726515',
+    '925679448263116681180549137562',
+    '927075281189608119735911336961',
+    '931660023131522836511470299550',
+    '934701751347399243333120058853',
+    '939103340398727679812199945201',
+    '952440288393343800284327753087',
+    '965523656856760127560055059644',
+    '969325292841504805778529336047',
+    '971476961920773226447199844576',
+    '988068515766013782236551550185',
+    '989440509183467842001314342301',
+    '995561512722026805270815340218',
+    '998144378008088787705870980497'
+]
+
+
+class LIDCDataModule(pl.LightningDataModule):
+    """PyTorch data module to load LIDC data"""
+
+    def __init__(
+        self,
+        working_dir: str = "./working",
+        augment_on_fly=True,
+        fold=0,
+        k_folds=5,
+        batch_size=5,
+        num_workers=4,
+        **kwargs,
+    ):
+        super().__init__()
+        self.working_dir = Path(working_dir)
+
+        self.fold = fold
+        self.k_folds = k_folds
+
+        self.train_cases = []
+        self.validation_cases = []
+        self.test_cases = []
+
+        self.augment_on_fly = augment_on_fly
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        self.training_set = None
+        self.validation_set = None
+        self.test_set = None
+
+        self.validation_data = []
+        self.test_data = []
+
+        self.num_observers = 4
+
+        print(f"Training fold {self.fold}")
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        """Add arguments used for Data module"""
+        parser = parent_parser.add_argument_group("LIDC Data Loader")
+        parser.add_argument("--augment_on_fly", type=bool, default=True)
+        parser.add_argument("--fold", type=int, default=0)
+        parser.add_argument("--k_folds", type=int, 
default=5) + parser.add_argument("--batch_size", type=int, default=5) + parser.add_argument("--num_workers", type=int, default=4) + + return parent_parser + + def setup(self, stage=None): + + cases = LIDC_PAT_IDS + cases.sort() + random.shuffle(cases) # will be consistent for same value of 'seed everything' + cases_per_fold = math.ceil(len(cases) / self.k_folds) + + for f in range(self.k_folds): + + if self.fold == f: + val_test_cases = cases[f * cases_per_fold : (f + 1) * cases_per_fold] + + if len(val_test_cases) == 1: + self.validation_cases = val_test_cases + else: + self.validation_cases = val_test_cases[: int(len(val_test_cases) / 2)] + # self.validation_cases = val_test_cases[: 5] + self.test_cases = val_test_cases[int(len(val_test_cases) / 2) :] + else: + self.train_cases += cases[f * cases_per_fold : (f + 1) * cases_per_fold] + + print(f"Training cases: {self.train_cases}") + print(f"Validation cases: {self.validation_cases}") + print(f"Testing cases: {self.test_cases}") + + augment_on_fly = self.augment_on_fly + + self.training_set = LIDCDataset( + self.working_dir, + augment_on_fly=augment_on_fly, + case_ids=self.train_cases + ) + print(f"Training Set Size: {len(self.training_set)}") + self.validation_set = LIDCDataset( + self.working_dir, + augment_on_fly=False, + case_ids=self.validation_cases + ) + print(f"Validation Set Size: {len(self.validation_set)}") + self.test_set = LIDCDataset( + self.working_dir, + augment_on_fly=False, + case_ids=self.test_cases + ) + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.validation_set, + batch_sampler=torch.utils.data.BatchSampler( + ObserverSampler(self.validation_set, self.num_observers), + batch_size=self.num_observers, + drop_last=False, + ), + num_workers=self.num_workers, + ) + + +class ProbUNet(pl.LightningModule): + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.save_hyperparameters() + + loss_params = None + + if self.hparams.loss_type == "elbo": + loss_params = { + "beta": self.hparams.beta, + } + + if self.hparams.loss_type == "geco": + loss_params = { + "kappa": self.hparams.kappa, + "clamp_rec": self.hparams.clamp_rec, + "clamp_contour": self.hparams.clamp_contour, + "kappa_contour": self.hparams.kappa_contour, + } + + loss_params["top_k_percentage"] = self.hparams.top_k_percentage + loss_params["contour_loss_lambda_threshold"] = self.hparams.contour_loss_lambda_threshold + loss_params["contour_loss_weight"] = self.hparams.contour_loss_weight + + if self.hparams.prob_type == "prob": + self.prob_unet = ProbabilisticUnet( + self.hparams.input_channels, + 2, + self.hparams.filters_per_layer, + self.hparams.latent_dim, + self.hparams.no_convs_fcomb, + self.hparams.loss_type, + loss_params, + 2, + ) + elif self.hparams.prob_type == "hierarchical": + self.prob_unet = HierarchicalProbabilisticUnet( + input_channels=self.hparams.input_channels, + num_classes=2, + filters_per_layer=self.hparams.filters_per_layer, + down_channels_per_block=self.hparams.down_channels_per_block, + latent_dims=[self.hparams.latent_dim] * (len(self.hparams.filters_per_layer) - 1), + convs_per_block=self.hparams.convs_per_block, + blocks_per_level=self.hparams.blocks_per_level, + loss_type=self.hparams.loss_type, + loss_params=loss_params, + ndims=2, + ) + + self.validation_directory = None + self.kl_div = None + + @staticmethod + def 
add_model_specific_args(parent_parser): + parser = parent_parser.add_argument_group("Probabilistic UNet") + parser.add_argument("--prob_type", type=str, default="prob") + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--lr_lambda", type=float, default=0.99) + parser.add_argument("--input_channels", type=int, default=1) + parser.add_argument( + "--filters_per_layer", nargs="+", type=int, default=[64 * (2 ** x) for x in range(5)] + ) + parser.add_argument("--down_channels_per_block", nargs="+", type=int, default=None) + parser.add_argument("--latent_dim", type=int, default=6) + parser.add_argument("--no_convs_fcomb", type=int, default=4) + parser.add_argument("--convs_per_block", type=int, default=2) + parser.add_argument("--blocks_per_level", type=int, default=1) + parser.add_argument("--loss_type", type=str, default="elbo") + parser.add_argument("--beta", type=float, default=1.0) + parser.add_argument("--kappa", type=float, default=0.02) + parser.add_argument("--kappa_contour", type=float, default=None) + parser.add_argument("--clamp_rec", nargs="+", type=float, default=[1e-5, 1e5]) + parser.add_argument("--clamp_contour", nargs="+", type=float, default=[1e-3, 1e3]) + parser.add_argument("--top_k_percentage", type=float, default=None) + parser.add_argument("--contour_loss_lambda_threshold", type=float, default=None) + parser.add_argument("--contour_loss_weight", type=float, default=0.0) # no longer used + parser.add_argument("--epochs_all_rec", type=int, default=0) # no longer used + + return parent_parser + + def forward(self, x): + self.prob_unet.forward(x, None, False) + return x + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.hparams.learning_rate, weight_decay=0 + ) + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=[lambda epoch: self.hparams.lr_lambda ** (epoch)] + ) + + return [optimizer], [scheduler] + + def frange_cycle_linear(self, start, stop, n_epoch, n_cycle=4, ratio=0.5): + L = np.ones(n_epoch) + period = n_epoch/n_cycle + step = (stop-start)/(period*ratio) # linear schedule + for c in range(n_cycle): + v , i = start , 0 + while v <= stop and (int(i+c*period) < n_epoch): + L[int(i+c*period)] = v + v += step + i += 1 + return L + + def training_step(self, batch, _): + + x, y, m, _ = batch + + # Add background layer for one-hot encoding + #y = torch.unsqueeze(y, dim=1) + not_y = 1 - y.max(axis=1).values + not_y = torch.unsqueeze(not_y, dim=1) + y = torch.cat((not_y, y), dim=1).float() + + if self.hparams.prob_type == "prob": + self.prob_unet.forward(x, y, training=True) + else: + self.prob_unet.forward(x, y) + + if self.hparams.prob_type == "prob": + beta_vals = self.frange_cycle_linear(0.0, 0.01, 100, 4, 1.0) +# loss = self.prob_unet.loss(y, mask=m, beta=beta_vals[self.current_epoch]) + loss = self.prob_unet.loss(y, mask=m) + else: + loss = self.prob_unet.loss(x, y, mask=m) + + training_loss = loss["loss"] + + if self.hparams.prob_type == "prob": + reg_loss = ( + l2_regularisation(self.prob_unet.posterior) + + l2_regularisation(self.prob_unet.prior) + + l2_regularisation(self.prob_unet.fcomb.layers) + ) + training_loss = training_loss + 1e-5 * reg_loss + self.log( + "training_loss", + training_loss.detach(), + on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) + + self.kl_div = loss["kl_div"].detach().cpu() + + for k in loss: + if k == "loss": + continue + self.log( + k, + loss[k].detach() if isinstance(loss[k], torch.Tensor) else loss[k], + 
on_step=True, + on_epoch=False, + prog_bar=True, + logger=True, + ) + return training_loss + + def validation_step(self, batch, _): + + n = 4 + m = 4 + + with torch.set_grad_enabled(False): + x, y, _, info = batch + + # Image will be same for all in batch + x = x[0, :, :, :].unsqueeze(0) + vis = ImageVisualiser(sitk.GetImageFromArray(x.to("cpu")[0,:,:,:]), axis="z") + x = x.repeat(m, 1, 1, 1) + self.prob_unet.forward(x) + + py = self.prob_unet.sample(testing=True) + py = py.to("cpu") + # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + # pred_y = torch.sigmoid(pred_y) + # print(f"{pred_y[0,:,:,:].min()} {pred_y[0,:,:,:].max()}") + + # pred_y = pred_y[:,1,:,:] > 0.5 + # pred_y = pred_y.unsqueeze(1) + pred_y = torch.zeros(py[:,0,:].shape).int() + for b in range(py.shape[0]): + pred_y[b] = py[b,:].argmax(0).int() + + y = y.squeeze(1) + y = y.to("cpu") + + # Intersection over Union (also known as Jaccard Index) + jaccard = JaccardIndex(num_classes=2) + term_1 = 0 + for i in range(n): + for j in range(m): + if pred_y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], y[j]) + term_1 += 1 - iou + term_1 = term_1 * (2/(m*n)) + + term_2 = 0 + for i in range(n): + for j in range(n): + if pred_y[i].sum() + pred_y[j].sum() == 0: + continue + iou = jaccard(pred_y[i], pred_y[j]) + term_2 += 1 - iou + term_2 = term_2 * (1/(n*n)) + + term_3 = 0 + for i in range(m): + for j in range(m): + if y[i].sum() + y[j].sum() == 0: + continue + iou = jaccard(y[i], y[j]) + term_3 += 1 - iou + term_3 = term_3 * (1/(m*m)) + + D_ged = term_1 - term_2 - term_3 + + contours = {} + for o in range(n): + obs_y = y[o].float() + obs_y = obs_y.unsqueeze(0) + contours[f"obs_{o}"] = sitk.GetImageFromArray(obs_y) + for mm in range(m): + samp_pred = pred_y[mm].float() + samp_pred = samp_pred.unsqueeze(0) + contours[f"sample_{mm}"] = sitk.GetImageFromArray(samp_pred) + + vis.add_contour(contours, colormap=plt.cm.get_cmap("cool")) + vis.show() + + figure_path = "valid.png" + plt.savefig(figure_path, dpi=300) + plt.close("all") + + try: + self.logger.experiment.log_image(figure_path) + except AttributeError: + # Likely offline mode + pass + + self.log("GED", D_ged) + return D_ged + + +def main(args, config_json_path=None): + + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + # args.default_root_dir = str(args.working_dir) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + + data_module = LIDCDataModule(**dict_args) + + prob_unet = ProbUNet(**dict_args) + + if args.resume_from is not None: + trainer = 
pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + + lr_monitor = LearningRateMonitor(logging_interval="step") + trainer.callbacks.append(lr_monitor) + + # Save the best model + checkpoint_callback = ModelCheckpoint( + monitor="GED", + dirpath=args.default_root_dir, + filename="probunet-{epoch:02d}-{GED:.2f}", + save_top_k=1, + mode="min", + ) + trainer.callbacks.append(checkpoint_callback) + + trainer.fit(prob_unet, data_module) + + +if __name__ == "__main__": + + args = None + config_json_path = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... + if sys.argv[-1].endswith(".json"): + config_json_path = sys.argv[-1] + with open(config_json_path, "r") as f: + params = json.load(f) + args = [] + for key in params: + args.append(f"--{key}") + + if isinstance(params[key], list): + for list_val in params[key]: + args.append(str(list_val)) + else: + args.append(str(params[key])) + + arg_parser = ArgumentParser() + arg_parser = ProbUNet.add_model_specific_args(arg_parser) + arg_parser = LIDCDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--experiment", type=str, default="lidc", help="Name of experiment") + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + + main(arg_parser.parse_args(args), config_json_path=config_json_path) diff --git a/platipy/imaging/cnn/train_localise.py b/platipy/imaging/cnn/train_localise.py new file mode 100644 index 00000000..97b97f91 --- /dev/null +++ b/platipy/imaging/cnn/train_localise.py @@ -0,0 +1,140 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
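+
+"""Training script for a LocaliseUNet: a coarse model intended to roughly localise a
+structure of interest (note the default 3 mm spacing below) before finer segmentation
+is performed. Arguments can be given on the command line or via a JSON config file;
+see the __main__ block at the bottom of this file.
+"""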
+ +import sys +import os +import json + +from pathlib import Path +from argparse import ArgumentParser + +import comet_ml # pylint: disable=unused-import +from pytorch_lightning.loggers import CometLogger + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint + +from platipy.imaging.cnn.localise_net import LocaliseUNet +from platipy.imaging.cnn.dataload import UNetDataModule + + +def main(args, config_json_path=None): + + pl.seed_everything(args.seed, workers=True) + + args.working_dir = Path(args.working_dir) + args.working_dir = args.working_dir.joinpath(args.experiment) + args.fold_dir = args.working_dir.joinpath(f"fold_{args.fold}") + args.default_root_dir = str(args.fold_dir) + args.num_sanity_val_steps = 0 + + comet_api_key = None + comet_workspace = None + comet_project = None + + if args.comet_api_key: + comet_api_key = args.comet_api_key + comet_workspace = args.comet_workspace + comet_project = args.comet_project + + if comet_api_key is None: + if "COMET_API_KEY" in os.environ: + comet_api_key = os.environ["COMET_API_KEY"] + if "COMET_WORKSPACE" in os.environ: + comet_workspace = os.environ["COMET_WORKSPACE"] + if "COMET_PROJECT" in os.environ: + comet_project = os.environ["COMET_PROJECT"] + + if comet_api_key is not None: + comet_logger = CometLogger( + api_key=comet_api_key, + workspace=comet_workspace, + project_name=comet_project, + experiment_name=args.experiment, + save_dir=args.working_dir, + offline=args.offline, + ) + if config_json_path: + comet_logger.experiment.log_code(config_json_path) + + dict_args = vars(args) + dict_args["validation_sampler"] = "batch" + + data_module = UNetDataModule(**dict_args) + + prob_unet = LocaliseUNet(**dict_args) + + if args.resume_from is not None: + trainer = pl.Trainer(resume_from_checkpoint=args.resume_from) + else: + trainer = pl.Trainer.from_argparse_args(args) + + if comet_api_key is not None: + trainer.logger = comet_logger + + # Save the best model + checkpoint_callback = ModelCheckpoint( + monitor="DSC", + dirpath=args.default_root_dir, + filename="localise-{epoch:02d}-{DSC:.2f}", + save_top_k=1, + mode="max", + ) + + trainer.callbacks.append(checkpoint_callback) + + trainer.fit(prob_unet, data_module) + + +if __name__ == "__main__": + + args = None + config_json_path = None + if len(sys.argv) == 2: + # Check if JSON file parsed, if so read arguments from there... 
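+        # e.g. {"experiment": "spleen", "spacing": [3, 3, 3]} is converted to
+        # ["--experiment", "spleen", "--spacing", "3", "3", "3"] before parsing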
+        if sys.argv[-1].endswith(".json"):
+            config_json_path = sys.argv[-1]
+            with open(config_json_path, "r") as f:
+                params = json.load(f)
+            args = []
+            for key in params:
+                args.append(f"--{key}")
+
+                if isinstance(params[key], list):
+                    for list_val in params[key]:
+                        args.append(str(list_val))
+                else:
+                    args.append(str(params[key]))
+
+    arg_parser = ArgumentParser()
+    arg_parser = LocaliseUNet.add_model_specific_args(arg_parser)
+    arg_parser = UNetDataModule.add_model_specific_args(arg_parser)
+    arg_parser = pl.Trainer.add_argparse_args(arg_parser)
+    arg_parser.add_argument(
+        "--config", type=str, default=None, help="JSON file with parameters to load"
+    )
+    arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed")
+    arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment")
+    arg_parser.add_argument("--working_dir", type=str, default="./working")
+    arg_parser.add_argument("--num_observers", type=int, default=5)
+    arg_parser.add_argument("--spacing", nargs="+", type=int, default=[3, 3, 3])
+    arg_parser.add_argument("--offline", type=bool, default=False)
+    arg_parser.add_argument("--comet_api_key", type=str, default=None)
+    arg_parser.add_argument("--comet_workspace", type=str, default=None)
+    arg_parser.add_argument("--comet_project", type=str, default=None)
+    arg_parser.add_argument("--resume_from", type=str, default=None)
+    arg_parser.add_argument("--combine_observers", type=str, default="union")
+
+    main(arg_parser.parse_args(args), config_json_path=config_json_path)
diff --git a/platipy/imaging/cnn/unet.py b/platipy/imaging/cnn/unet.py
new file mode 100644
index 00000000..e16c8d0b
--- /dev/null
+++ b/platipy/imaging/cnn/unet.py
@@ -0,0 +1,270 @@
+# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Parts of this work are derived from:
+# https://github.com/stefanknegt/Probabilistic-Unet-Pytorch
+# which is released under the Apache Licence 2.0
+
+# pylint: disable=invalid-name
+
+import torch
+from torch import nn
+
+
+def conv_nd(ndims=2, **kwargs):
+    """Generate a 2D or 3D convolution
+
+    Args:
+        ndims (int, optional): 2 or 3 dimensions. Defaults to 2.
+
+    Raises:
+        NotImplementedError: Raised if ndims is not 2 or 3.
+
+    Returns:
+        torch.nn.Conv: The convolution.
+    """
+
+    if ndims == 2:
+        return torch.nn.Conv2d(**kwargs)
+    elif ndims == 3:
+        return torch.nn.Conv3d(**kwargs)
+
+    raise NotImplementedError("Only 2 or 3 dimensions are supported")
+
+
+def dropout_nd(ndims=2, **kwargs):
+    """Get a 2D or 3D dropout layer
+
+    Args:
+        ndims (int, optional): 2 or 3 dimensions. Defaults to 2.
+
+    Raises:
+        NotImplementedError: Raised if ndims is not 2 or 3.
+ + Returns: + torch.nn.Dropout: The dropout layer + """ + + if ndims == 2: + return torch.nn.Dropout2d(**kwargs) + elif ndims == 3: + return torch.nn.Dropout3d(**kwargs) + + raise NotImplementedError("Only 2 or 3 dimensions are supported") + + +def init_weights(m): + if ( + isinstance(m, torch.nn.Conv2d) + or isinstance(m, torch.nn.ConvTranspose2d) + or isinstance(m, torch.nn.Conv3d) + or isinstance(m, torch.nn.ConvTranspose3d) + ): + torch.nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") + truncated_normal_(m.bias, mean=0, std=0.001) + +def l2_regularisation(m): + l2_reg = None + + for W in m.parameters(): + if l2_reg is None: + l2_reg = W.norm(2) + else: + l2_reg = l2_reg + W.norm(2) + return l2_reg + + +def init_zeros(m): + if ( + isinstance(m, torch.nn.Conv2d) + or isinstance(m, torch.nn.ConvTranspose2d) + or isinstance(m, torch.nn.Conv3d) + or isinstance(m, torch.nn.ConvTranspose3d) + ): + torch.nn.init.zeros_(m.weight) + truncated_normal_(m.bias, mean=0, std=0.1) + + +def truncated_normal_(tensor, mean=0, std=1): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + + +def resize_down_func(scale=2, ndims=2): + """Returns function to resize the input to downsample + + Args: + scale (int, optional): The scale used to downsize. Defaults to 2. + ndims (int, optional): Number of dimensions (2 or 3). Defaults to 2. + + Returns: + function: The downsize function + """ + if ndims == 3: + return torch.nn.MaxPool3d(kernel_size=scale, stride=scale, padding=0) + elif ndims == 2: + return torch.nn.MaxPool2d(kernel_size=scale, stride=scale, padding=0) + + raise NotImplementedError() + + +def resize_up_func(in_channels, out_channels, scale=2, ndims=2): + """Return function to resize the input to upsample + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + scale (int, optional): The scale used to upsize. Defaults to 2. + ndims (int, optional): Number of dimensions (2 or 3). Defaults to 2. 
+ + Raises: + NotImplementedError: Only supports 2d or 3d + + Returns: + function: The upsize function + """ + if ndims == 3: + return torch.nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=scale, + stride=scale, + ) + elif ndims == 2: + return torch.nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=scale, + stride=scale, + ) + raise NotImplementedError() + + +class Conv(torch.nn.Module): + def __init__( + self, input_channels, output_channels, up_down_sample=0, dropout_probability=0.0, ndims=2 + ): + + super(Conv, self).__init__() + + self.pre_op = None + size_and_stride = abs(up_down_sample) + if up_down_sample < 0: + self.pre_op = resize_down_func(size_and_stride, ndims=ndims) + elif up_down_sample > 0: + self.pre_op = resize_up_func( + input_channels, output_channels, size_and_stride, ndims=ndims + ) + + layers = [] + layers.append( + conv_nd( + ndims=ndims, + in_channels=input_channels, + out_channels=output_channels, + kernel_size=3, + padding=1, + ) + ) + layers.append(nn.ReLU(inplace=True)) + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) + layers.append( + conv_nd( + ndims=ndims, + in_channels=output_channels, + out_channels=output_channels, + kernel_size=3, + padding=1, + ) + ) + layers.append(dropout_nd(ndims=ndims, p=dropout_probability)) + layers.append(nn.ReLU(inplace=True)) + self.layers = nn.Sequential(*layers) + + self.layers.apply(init_weights) + + def forward(self, x, concat=None): + + if not self.pre_op is None: + x = self.pre_op(x) + + if not concat is None: + x = torch.cat([x, concat], 1) + + return self.layers(x) + + +class UNet(nn.Module): + def __init__( + self, + input_channels=1, + output_classes=2, + filters_per_layer=[64 * (2 ** x) for x in range(5)], + final_layer=True, + ndims=2, + dropout_probability=0.0 + ): + + super(UNet, self).__init__() + + self.encoder = nn.ModuleList() + for idx, layer_filters in enumerate(filters_per_layer): + input_filters = input_channels if idx == 0 else output_filters + output_filters = layer_filters + down_sample = 0 if idx == 0 else -2 + + self.encoder.append( + Conv(input_filters, output_filters, up_down_sample=down_sample, dropout_probability=dropout_probability, ndims=ndims) + ) + + reversed_filters = list(reversed(filters_per_layer)) + self.decoder = nn.ModuleList() + for idx, layer_filters in enumerate(reversed_filters): + + if idx == len(reversed_filters) - 1: + continue + + input_filters = layer_filters + output_filters = reversed_filters[idx + 1] + + self.decoder.append(Conv(input_filters, output_filters, up_down_sample=2, dropout_probability=dropout_probability, ndims=ndims)) + + self.final = None + if final_layer: + self.final = conv_nd( + ndims=ndims, + in_channels=filters_per_layer[0], + out_channels=output_classes, + kernel_size=1, + ) + + def forward(self, x): + + blocks = [] + for idx, enc in enumerate(self.encoder): + x = enc(x) + if idx != len(self.encoder) - 1: + blocks.append(x) + + for idx, dec in enumerate(self.decoder): + x = dec(x, concat=blocks[-idx - 1]) + + if self.final: + return self.final(x) + + return x diff --git a/platipy/imaging/cnn/utils.py b/platipy/imaging/cnn/utils.py new file mode 100644 index 00000000..b863b0c6 --- /dev/null +++ b/platipy/imaging/cnn/utils.py @@ -0,0 +1,201 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import SimpleITK as sitk
+
+from platipy.imaging.label.utils import get_union_mask, get_intersection_mask
+from platipy.imaging.label.comparison import (
+    compute_metric_dsc,
+    compute_metric_hd,
+    compute_metric_masd,
+)
+
+
+def get_contour_mask(masks, kernel=5):
+    """Returns a mask around the region where observer masks don't agree
+
+    Args:
+        masks (list): List of observer masks (as sitk.Image)
+        kernel (int, optional): The size of the kernel used to dilate the contour. Defaults to 5.
+
+    Returns:
+        sitk.Image: The resulting contour mask
+    """
+
+    if not hasattr(kernel, "__iter__"):
+        kernel = (kernel,) * 3
+
+    union_mask = get_union_mask(masks)
+    intersection_mask = get_intersection_mask(masks)
+
+    union_mask = sitk.BinaryDilate(union_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall)
+    intersection_mask = sitk.BinaryErode(
+        intersection_mask, np.abs(kernel).astype(int).tolist(), sitk.sitkBall
+    )
+
+    return union_mask - intersection_mask
+
+
+def preprocess_image(
+    img,
+    spacing=[1, 1, 1],
+    crop_to_grid_size_xy=128,
+    intensity_scaling="window",
+    intensity_window=[-500, 500],
+):
+    """Preprocess an image to prepare it for use in a CNN.
+
+    Args:
+        img (sitk.Image): The image to preprocess.
+        spacing (list, optional): The voxel spacing (in mm) to resample the image to.
+            Defaults to [1, 1, 1].
+        crop_to_grid_size_xy (int|list, optional): Crop to the center grid of this size in x and y
+            direction. May be an int value which will be used for both x and y size, or a list
+            containing two int values for x and y. Defaults to 128.
+        intensity_scaling (str, optional): How to scale the intensity values. Should be one of
+            'norm' (center mean and unit variance), 'window' (map window [min max] to [-1 1]),
+            'none' (no intensity scaling applied). Defaults to "window".
+        intensity_window (list, optional): List with min and max values to be used when
+            intensity_scaling is 'window'. Not used otherwise. Defaults to [-500, 500].
+
+    Returns:
+        sitk.Image: The preprocessed image.
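+
+    Example (assuming a CT image loaded with SimpleITK):
+        >>> img = sitk.ReadImage("ct.nii.gz")
+        >>> img = preprocess_image(img, spacing=[1, 1, 1], crop_to_grid_size_xy=128)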
+    """
+
+    img = sitk.Cast(img, sitk.sitkFloat32)
+    if intensity_scaling == "norm":
+        img = sitk.Normalize(img)
+    elif intensity_scaling == "window":
+        img = sitk.IntensityWindowing(
+            img,
+            windowMinimum=intensity_window[0],
+            windowMaximum=intensity_window[1],
+            outputMinimum=-1.0,
+            outputMaximum=1.0,
+        )
+    elif intensity_scaling != "none" and intensity_scaling is not None:
+        raise ValueError("intensity_scaling should be one of: 'norm', 'window', 'none'")
+
+    new_size = sitk.VectorUInt32(3)
+    new_size[0] = int(img.GetSize()[0] * (img.GetSpacing()[0] / spacing[0]))
+    new_size[1] = int(img.GetSize()[1] * (img.GetSpacing()[1] / spacing[1]))
+    new_size[2] = int(img.GetSize()[2] * (img.GetSpacing()[2] / spacing[2]))
+
+    if crop_to_grid_size_xy:
+
+        if not hasattr(crop_to_grid_size_xy, "__iter__"):
+            crop_to_grid_size_xy = (crop_to_grid_size_xy,) * 2
+
+        if new_size[0] < crop_to_grid_size_xy[0]:
+            new_size[0] = crop_to_grid_size_xy[0]
+
+        if new_size[1] < crop_to_grid_size_xy[1]:
+            new_size[1] = crop_to_grid_size_xy[1]
+
+    img = sitk.Resample(
+        img,
+        new_size,
+        sitk.Transform(),
+        sitk.sitkLinear,
+        img.GetOrigin(),
+        spacing,
+        img.GetDirection(),
+        -1,
+        img.GetPixelID(),
+    )
+
+    if crop_to_grid_size_xy:
+        center_x = img.GetSize()[0] / 2
+        x_from = int(center_x - crop_to_grid_size_xy[0] / 2)
+        x_to = x_from + crop_to_grid_size_xy[0]
+
+        center_y = img.GetSize()[1] / 2
+        y_from = int(center_y - crop_to_grid_size_xy[1] / 2)
+        y_to = y_from + crop_to_grid_size_xy[1]
+
+        img = img[x_from:x_to, y_from:y_to, :]
+
+    return img
+
+
+def resample_mask_to_image(img, mask):
+    """Resample a mask to the space of the image supplied.
+
+    Args:
+        img (sitk.Image): Image to sample to space of.
+        mask (sitk.Image): Mask to resample.
+
+    Returns:
+        sitk.Image: The resampled mask.
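+
+    Example (assuming img and mask are SimpleITK.Image objects on different grids):
+        >>> mask_resampled = resample_mask_to_image(img, mask)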
+ """ + + return sitk.Resample( + mask, + img, + sitk.Transform(), + sitk.sitkNearestNeighbor, + 0, + mask.GetPixelID(), + ) + + +def postprocess_mask(pred): + """Perform postprocessing on a generated auto-segmentation + + Args: + pred (sitk.Image): The predicted mask + + Returns: + sitk.Image: The postprocessed mask + """ + + # Take only the largest componenet + labelled_image = sitk.ConnectedComponent(pred) + label_shape_filter = sitk.LabelShapeStatisticsImageFilter() + label_shape_filter.Execute(labelled_image) + label_indices = label_shape_filter.GetLabels() + voxel_counts = [label_shape_filter.GetNumberOfPixels(i) for i in label_indices] + if len(voxel_counts) > 0: + largest_component_label = label_indices[np.argmax(voxel_counts)] + largest_component_image = labelled_image == largest_component_label + pred = sitk.Cast(largest_component_image, sitk.sitkUInt8) + + # Fill any holes in the structure + pred = sitk.BinaryMorphologicalClosing(pred, (5, 5, 5)) + pred = sitk.BinaryFillhole(pred) + + return pred + + +def get_metrics(target, pred): + + result = {} + result["DSC"] = compute_metric_dsc(target, pred) + + target_pixels = sitk.GetArrayFromImage(target) + pred_pixels = sitk.GetArrayFromImage(pred) + + if pred_pixels.max() == 0 and target_pixels.max() == 0: + result["HD"] = 0 + result["ASD"] = 0 + elif pred_pixels.max() == 0 or target_pixels.max() == 0 or pred_pixels.min() == 1 or target_pixels.min() == 1: + result["HD"] = 1000 + result["ASD"] = 100 + else: + result["HD"] = compute_metric_hd(target, pred, auto_crop=False) + result["ASD"] = compute_metric_masd(target, pred, auto_crop=False) + + return result diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index b6fd7455..d8fca7ec 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -15,8 +15,18 @@ from abc import ABC, abstractmethod from collections.abc import Iterable import random +import logging +import sys + +from pathlib import Path + +from argparse import ArgumentParser import SimpleITK as sitk +import numpy as np + +import matplotlib.pyplot as plt +from platipy.imaging import ImageVisualiser from platipy.imaging.generation.dvf import ( generate_field_shift, @@ -29,9 +39,15 @@ from platipy.imaging.registration.utils import apply_transform +from platipy.imaging.utils.lung import detect_holes +from platipy.imaging.label.utils import get_union_mask +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi -def apply_augmentation(image, augmentation, masks=[]): +logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) +logger = logging.getLogger(__name__) + +def apply_augmentation(image, augmentation, context_map=None, masks=[]): if not isinstance(image, sitk.Image): raise AttributeError("image should be a SimpleITK.Image") @@ -44,23 +60,26 @@ def apply_augmentation(image, augmentation, masks=[]): "DeformableAugment's" ) - transforms = [] + # transforms = [] + transform = None dvf = None for aug in augmentation: - if not isinstance(aug, DeformableAugment): raise AttributeError("Each augmentation must be of type DeformableAugment") + logger.debug(str(aug)) tfm, field = aug.augment() - transforms.append(tfm) + # transforms.append(tfm) if dvf is None: dvf = field + transform = tfm else: dvf += field + transform = sitk.CompositeTransform([transform, tfm]) - transform = sitk.CompositeTransform(transforms) - del transforms + # transform = sitk.CompositeTransform(transforms) + # del transforms image_deformed = apply_transform( image, @@ -71,59 
+90,54 @@ def apply_augmentation(image, augmentation, masks=[]): masks_deformed = [] for mask in masks: - masks_deformed.append( - apply_transform( - mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor - ) + def_mask = apply_transform( + mask, + transform=transform, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, + ) + + def_mask = sitk.BinaryMorphologicalClosing(def_mask, [3, 3, 3]) + + masks_deformed.append(def_mask) + + cmap_deformed = None + if context_map is not None: + cmap_deformed = apply_transform( + context_map, + transform=transform, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, ) if masks: - return image_deformed, masks_deformed, dvf - - return image_deformed, dvf - - -def generate_random_augmentation(ct_image, masks): - - random.shuffle(masks) - # mask_count = len(masks) - # masks = masks[: random.randint(2, 5)] - - # print(len(masks)) - augmentation_types = [ - { - "class": ShiftAugment, - "args": {"vector_shift": [(-10, 10), (10, 10), (-10, 10)], "gaussian_smooth": (3, 5)}, - }, - { - "class": ContractAugment, - "args": { - "vector_contract": [(0, 10), (0, 10), (0, 10)], - "gaussian_smooth": (3, 5), - "bone_mask": True, - }, - }, - { - "class": ExpandAugment, - "args": { - "vector_expand": [(0, 10), (0, 10), (0, 10)], - "gaussian_smooth": (3, 5), - "bone_mask": True, - }, - }, - ] + return image_deformed, cmap_deformed, masks_deformed, dvf + + return image_deformed, cmap_deformed, dvf + +def generate_random_augmentation(ct_image, masks, augmentation_types): augmentation = [] + + probabilities = [a["probability"] for a in augmentation_types] + prob_total = sum(probabilities) + prob_none = 1.0 - prob_total + if prob_none < 0: + prob_none = 0 + for mask in masks: - aug = random.choice(augmentation_types) + aug = random.choices( + augmentation_types + [None], weights=probabilities + [prob_none] + )[0] + + if aug is None: + continue aug_class = aug["class"] aug_args = {} for arg in aug["args"]: - value = aug["args"][arg] if isinstance(value, list): - # Randomly sample for each dim result = [] for rng in value: @@ -144,20 +158,17 @@ def generate_random_augmentation(ct_image, masks): class DeformableAugment(ABC): @abstractmethod def augment(self): - # return deformation pass class ShiftAugment(DeformableAugment): def __init__(self, mask, vector_shift=(10, 10, 10), gaussian_smooth=5): - self.mask = mask self.vector_shift = vector_shift self.gaussian_smooth = gaussian_smooth def augment(self): - _, transform, dvf = generate_field_shift( self.mask, self.vector_shift, @@ -165,17 +176,20 @@ def augment(self): ) return transform, dvf + def __str__(self): + return f"Shift with vector: {self.vector_shift}, gauss: {self.gaussian_smooth}" -class ExpandAugment(DeformableAugment): - def __init__(self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False): +class ExpandAugment(DeformableAugment): + def __init__( + self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False + ): self.mask = mask self.vector_expand = vector_expand self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask def augment(self): - _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, @@ -185,21 +199,308 @@ def augment(self): return transform, dvf + def __str__(self): + return ( + f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + ) -class ContractAugment(DeformableAugment): - def __init__(self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False): +class 
ContractAugment(DeformableAugment): + def __init__( + self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False + ): self.mask = mask - self.contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] + self.vector_contract = [ + int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing()) + ] self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask def augment(self): - _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, - expand=self.contract, + expand=self.vector_contract, gaussian_smooth=self.gaussian_smooth, ) return transform, dvf + + def __str__(self): + return f"Contract with vector: {self.vector_contract}, smooth: {self.gaussian_smooth}" + + +def augment_data(args): + random.seed(args.seed) + + augmentation_types = [] + + if args.enable_shift: + augmentation_types.append( + { + "class": ShiftAugment, + "args": { + "vector_shift": [ + tuple(args.shift_x_range), + tuple(args.shift_y_range), + tuple(args.shift_z_range), + ], + "gaussian_smooth": tuple(args.shift_smooth_range), + }, + "probability": args.shift_probability, + } + ) + + if args.enable_contract: + augmentation_types.append( + { + "class": ContractAugment, + "args": { + "vector_contract": [ + tuple(args.contract_x_range), + tuple(args.contract_y_range), + tuple(args.contract_z_range), + ], + "gaussian_smooth": tuple(args.contract_smooth_range), + "bone_mask": args.contract_bone_mask, + }, + "probability": args.contract_probability, + } + ) + + if args.enable_expand: + augmentation_types.append( + { + "class": ExpandAugment, + "args": { + "vector_expand": [ + tuple(args.expand_x_range), + tuple(args.expand_y_range), + tuple(args.expand_z_range), + ], + "gaussian_smooth": tuple(args.expand_smooth_range), + "bone_mask": args.expand_bone_mask, + }, + "probability": args.expand_probability, + } + ) + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + + cases = [ + p.name.replace(".nii.gz", "") + for p in data_dir.glob(args.case_glob) + if not p.name.startswith(".") + ] + cases.sort() + + data = { + case: { + "image": data_dir.joinpath(args.image_glob.format(case=case)), + "context_map": data_dir.joinpath(args.context_map_glob.format(case=case)), + "label": [ + i + for sl in [ + list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob + ] + for i in sl + ], + } + for case in cases + } + + for case in cases: + logger.info(f"Augmenting for case: {case}") + + ct_image_original = sitk.ReadImage(str(data[case]["image"])) + + cmap_original = None + if data[case]["context_map"]: + cmap_original = sitk.ReadImage(str(data[case]["context_map"])) + + # Get list of structures to generate augmentations off + logger.debug("Collecting structures") + all_masks = [] + all_names = [] + for structure_path in data[case]["label"]: + mask = sitk.ReadImage(str(structure_path)) + + all_masks.append(mask) + all_names.append(structure_path.name.replace(".nii.gz", "")) + + logger.debug("Cropping to regions around all structures") + union_mask = get_union_mask(all_masks) + size, index = label_to_roi(union_mask, expansion_mm=[25, 25, 25]) + ct_image = crop_to_roi(ct_image_original, size, index) + + for m, mask in enumerate(all_masks): + all_masks[m] = crop_to_roi(mask, size, index) + + if args.enable_fill_holes: + logger.debug("Finding holes") + label_image, labels = detect_holes(ct_image) + + # Generate x random augmentations per case + for i in range(args.augmentations_per_case): + logger.debug(f"Generating augmentation {i}") + + ct_image = 
sitk.ReadImage(str(data[case]["image"])) + ct_image = crop_to_roi(ct_image, size, index) + + cmap = None + if data[case]["context_map"]: + cmap = sitk.ReadImage(str(data[case]["context_map"])) + cmap = crop_to_roi(cmap, size, index) + + if args.enable_fill_holes: + logger.debug("Filling holes") + + for label in labels[1:]: # Skip first hole since likely air around body + if random.random() > args.fill_probability: + continue + + hole = label_image == label["label"] + hole_dilate = sitk.BinaryDilate(hole, (2, 2, 2), sitk.sitkBall) + contour_points = sitk.BinaryContour(hole_dilate) + fill_value = np.median( + sitk.GetArrayFromImage(ct_image)[ + sitk.GetArrayFromImage(contour_points) == 1 + ] + ) + + ct_arr = sitk.GetArrayFromImage(ct_image) + ct_arr[sitk.GetArrayFromImage(hole_dilate) == 1] = fill_value + ct_filled = sitk.GetImageFromArray(ct_arr) + ct_filled.CopyInformation(ct_image) + + ct_image = ct_filled + + augmented_case_path = output_dir.joinpath(case, f"augment_{i}") + augmented_case_path.mkdir(exist_ok=True, parents=True) + + logger.debug("Generating random augmentations") + augmentation = generate_random_augmentation( + ct_image, all_masks, augmentation_types + ) + + dvf = None + augmented_cmap = None + + if len(augmentation) == 0: + logger.debug( + "No augmentations generated, generated image won't differ from original" + ) + + augmented_image = ct_image + augmented_masks = all_masks + else: + logger.debug("Applying augmentation") + ( + augmented_image, + augmented_cmap, + augmented_masks, + dvf, + ) = apply_augmentation( + ct_image, augmentation, context_map=cmap, masks=all_masks + ) + + # Save off image + augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") + ct_image_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_image + sitk.WriteImage(ct_image_original, str(augmented_image_path)) + + # Save off context map if we have one + if augmented_cmap: + augmented_cmap_path = augmented_case_path.joinpath("context_map.nii.gz") + cmap_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_cmap + sitk.WriteImage(cmap_original, str(augmented_cmap_path)) + + vis = ImageVisualiser(image=ct_image, figure_size_in=6) + vis.add_comparison_overlay(augmented_image) + if dvf is not None: + vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) + for mask_name, mask, augmented_mask in zip( + all_names, all_masks, augmented_masks + ): + vis.add_contour( + {f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask} + ) + + logger.debug(f"Applying augmentation to mask: {mask_name}") + augmented_mask_path = augmented_case_path.joinpath( + f"{mask_name}.nii.gz" + ) + augmented_mask = sitk.Resample( + augmented_mask, + ct_image_original, + sitk.Transform(), + sitk.sitkNearestNeighbor, + ) + sitk.WriteImage(augmented_mask, str(augmented_mask_path)) + + fig = vis.show() + + figure_path = augmented_case_path.joinpath("aug.png") + fig.savefig(figure_path, bbox_inches="tight") + plt.close() + + +if __name__ == "__main__": + arg_parser = ArgumentParser() + arg_parser.add_argument( + "--seed", type=int, default=42, help="an integer to use as seed" + ) + arg_parser.add_argument("--data_dir", type=str, default="./data") + arg_parser.add_argument("--output_dir", type=str, default="./augment") + arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") + arg_parser.add_argument("--image_glob", type=str, 
default="images/{case}.nii.gz") + arg_parser.add_argument( + "--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz" + ) + arg_parser.add_argument("--context_map_glob", type=str, default=None) + arg_parser.add_argument( + "--augmentations_per_case", + type=int, + default=10, + help="How many augmented images per case to generate", + ) + + arg_parser.add_argument("--enable_shift", type=bool, default=True) + arg_parser.add_argument("--shift_x_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_y_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_z_range", nargs="+", type=int, default=[-10, 10]) + arg_parser.add_argument("--shift_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument("--shift_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_expand", type=bool, default=True) + arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument( + "--expand_smooth_range", nargs="+", type=int, default=[3, 5] + ) + arg_parser.add_argument("--expand_bone_mask", type=bool, default=True) + arg_parser.add_argument("--expand_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_contract", type=bool, default=True) + arg_parser.add_argument("--contract_x_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10]) + arg_parser.add_argument( + "--contract_smooth_range", nargs="+", type=int, default=[3, 5] + ) + arg_parser.add_argument("--contract_bone_mask", type=bool, default=True) + arg_parser.add_argument("--contract_probability", type=float, default=0.5) + + arg_parser.add_argument("--enable_fill_holes", type=bool, default=True) + arg_parser.add_argument("--fill_probability", type=float, default=0.2) + + augment_data(arg_parser.parse_args()) diff --git a/platipy/imaging/generation/dvf.py b/platipy/imaging/generation/dvf.py index 9d0b84cc..efd9f7b9 100644 --- a/platipy/imaging/generation/dvf.py +++ b/platipy/imaging/generation/dvf.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import numpy as np import SimpleITK as sitk @@ -25,13 +26,17 @@ fast_symmetric_forces_demons_registration, ) +from platipy.imaging.label.utils import get_com +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi -def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth=5): +logger = logging.getLogger(__name__) + +def generate_field_shift(mask, vector_shift=(10, 10, 10), gaussian_smooth=5): """ Shifts (moves) a structure defined using a binary mask. Args: - mask_image ([SimpleITK.Image]): The binary mask to shift. + mask ([SimpleITK.Image]): The binary mask to shift. vector_shift (tuple, optional): The displacement vector applied to the entire binary mask. Convention: (+/-, +/-, +/-) = (sup/inf, post/ant, left/right) shift. @@ -45,9 +50,26 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= [SimpleITK.DisplacementFieldTransform]: The transform representing the shift. [SimpleITK.Image]: The displacement vector field representing the shift. 
""" + + mask_full = mask + + roi_expand = [x + 5 for x in vector_shift] + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + + # Make sure the expansion meets a minimum size (1cm) + roi_expand = [max(e, 10) for e in roi_expand] + size, index = label_to_roi(mask, expansion_mm=roi_expand) + mask = crop_to_roi(mask, size, index) + # Define array # Used for image array manipulations - mask_image_arr = sitk.GetArrayFromImage(mask_image) + mask_image_arr = sitk.GetArrayFromImage(mask) # The template deformation field # Used to generate transforms @@ -56,29 +78,28 @@ def generate_field_shift(mask_image, vector_shift=(10, 10, 10), gaussian_smooth= dvf_template = sitk.GetImageFromArray(dvf_arr) # Copy image information - dvf_template.CopyInformation(mask_image) + dvf_template.CopyInformation(mask) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) - mask_image_shift = apply_transform( - mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask_shift = apply_transform( + mask, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) - dvf_template = sitk.Mask(dvf_template, mask_image | mask_image_shift) + dvf_template = sitk.Mask(dvf_template, mask | mask_shift) # smooth if np.any(gaussian_smooth): - - if not hasattr(gaussian_smooth, "__iter__"): - gaussian_smooth = (gaussian_smooth,) * 3 - dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + # Resample back to original image + dvf_template = sitk.Resample(dvf_template, mask_full) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) - mask_image_shift = apply_transform( - mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + + mask_shift = apply_transform( + mask_full, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) - return mask_image_shift, dvf_tfm, dvf_template + return mask_shift, dvf_tfm, dvf_template def generate_field_asymmetric_contract( @@ -228,32 +249,49 @@ def generate_field_expand( dilation kernel. Args: - mask ([SimpleITK.Image]): The binary mask to expand. - bone_mask ([SimpleITK.Image, optional]): A binary mask defining regions where we expect - restricted deformations. - vector_asymmetric_extend (int |tuple, optional): The expansion vector applied to the entire - binary mask. - Convention: (z,y,x) size of expansion kernel. - Defined in millimetres. - Defaults to 3. + mask (SimpleITK.Image): The binary mask to expand. + bone_mask (SimpleITK.Image, optional): A binary mask defining regions where we expect + restricted deformations. + expand (int |tuple, optional): The expansion vector applied to the entire binary mask. + Convention: (z,y,x) size of expansion kernel. + Defined in millimetres. + Defaults to 3. gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the - deformation vector field. Defaults to 5. + deformation vector field. Defaults to 5. Returns: - [SimpleITK.Image]: The binary mask following the expansion. - [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion. - [SimpleITK.Image]: The displacement vector field representing the expansion. + SimpleITK.Image: The binary mask following the expansion. + SimpleITK.DisplacementFieldTransform: The transform representing the expansion. 
+ SimpleITK.Image: The displacement vector field representing the expansion. """ + mask_full = mask + + if not hasattr(expand, "__iter__"): + expand = (expand,) * 3 + + roi_expand = expand + + if np.any(gaussian_smooth): + + if not hasattr(gaussian_smooth, "__iter__"): + gaussian_smooth = (gaussian_smooth,) * 3 + + roi_expand = [x + y for x, y in zip(roi_expand, gaussian_smooth)] + + # Make sure the expansion meets a minimum size (1cm) + roi_expand = [max(e, 10) for e in roi_expand] + + size, index = label_to_roi(mask, expansion_mm=roi_expand) + mask = crop_to_roi(mask, size, index) + if bone_mask is not False: + bone_mask = sitk.Resample(bone_mask, mask, sitk.Transform(), sitk.sitkNearestNeighbor) mask_original = mask + bone_mask else: mask_original = mask # Use binary erosion to create a smaller volume - if not hasattr(expand, "__iter__"): - expand = (expand,) * 3 - expand = np.array(expand) # Convert voxels to millimetres @@ -265,17 +303,17 @@ def generate_field_expand( # If all negative: erode if np.all(np.array(expand) <= 0): - print("All factors negative: shrinking only.") + logger.debug("All factors negative: shrinking only.") mask_expand = sitk.BinaryErode(mask, np.abs(expand).astype(int).tolist(), sitk.sitkBall) # If all positive: dilate elif np.all(np.array(expand) >= 0): - print("All factors positive: expansion only.") + logger.debug("All factors positive: expansion only.") mask_expand = sitk.BinaryDilate(mask, np.abs(expand).astype(int).tolist(), sitk.sitkBall) # Otherwise: sequential operations else: - print("Mixed factors: shrinking and expansion.") + logger.debug("Mixed factors: shrinking and expansion.") expansion_kernel = expand * (expand > 0) shrink_kernel = expand * (expand < 0) @@ -309,16 +347,14 @@ def generate_field_expand( # smooth if np.any(gaussian_smooth): - - if not hasattr(gaussian_smooth, "__iter__"): - gaussian_smooth = (gaussian_smooth,) * 3 - dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth) + # Resample back to original image + dvf_template = sitk.Resample(dvf_template, mask_full) dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64)) mask_symmetric_expand = apply_transform( - mask, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask_full, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor ) return mask_symmetric_expand, dvf_tfm, dvf_template @@ -413,3 +449,161 @@ def generate_field_radial_bend( ) return reference_image_bend, dvf_tfm, dvf_template + + +def expand_mask_towards_target( + mask_image, target_image, expand_mag=20, gaussian_smooth=5, dvf_overlap_into_mask=3 +): + """Generate a deformation vector field to expand a mask towards a target mask. Can be useful to + manipulate structures for augmentation of fail cases for automated contour QA work. + + Args: + mask_image (sitk.Image): The mask of the structure of manipulate + target_image (sitk.Image): The mask of the target structure to expand towards + expand_mag (int, optional): The magnitude of the expansion in mm. Defaults to 20. + gaussian_smooth (int, optional): Scale of a Gaussian kernel used to smooth the + deformation vector field.. Defaults to 5. + dvf_overlap_into_mask (int, optional): Defines how much overlap the deformation field + into the mask image. Effects how much of the structure is deformed. Defaults to 3. + + Returns: + [SimpleITK.Image]: The binary mask following the expansion. + [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion. 
+        [SimpleITK.Image]: The displacement vector field representing the expansion.
+    """
+
+    # Remove any potential overlap between the target and the mask
+    target_image = sitk.MaskNegated(target_image, mask_image)
+
+    # Determine the unit vector pointing from the mask towards the target
+    mask_com = get_com(mask_image, as_int=False, real_coords=True)
+    target_com = get_com(target_image, as_int=False, real_coords=True)
+
+    expand_vec = np.array([p - q for p, q in zip(target_com, mask_com)])
+    expand_vec = expand_vec / np.linalg.norm(expand_vec)
+
+    mask_image_arr = sitk.GetArrayFromImage(mask_image)
+
+    # Compute the distance map from the target to every other voxel
+    dist_map = sitk.SignedMaurerDistanceMap(target_image, squaredDistance=False)
+    dist_map_arr = sitk.GetArrayFromImage(dist_map)
+    dist_map_arr[dist_map_arr < 0] = 0
+
+    # Manipulate the distance map so that only voxels within the range of
+    # dvf_overlap_into_mask are kept
+    dist_from_mask_to_target = dist_map_arr[mask_image_arr > 0].min()
+    max_mask_dist = dist_map_arr[mask_image_arr > 0].max()
+    dist_map_arr[dist_map_arr > max_mask_dist] = max_mask_dist
+    dist_map_arr[dist_map_arr > dist_from_mask_to_target + dvf_overlap_into_mask] = (
+        dist_from_mask_to_target + dvf_overlap_into_mask
+    )
+
+    dvf_weight = np.zeros(dist_map_arr.shape)
+    dvf_weight[dist_map_arr < dist_from_mask_to_target + dvf_overlap_into_mask] = 1
+    dvf_weight = np.tile(np.expand_dims(dvf_weight, axis=3), [1, 1, 1, 3])
+
+    # The template deformation field
+    # Used to generate transforms
+    dvf_arr = np.zeros(mask_image_arr.shape + (3,))
+    dvf_arr = dvf_arr - np.array([[[expand_vec * expand_mag]]])
+
+    # Weight the deformation vectors by the manipulated distance map
+    dvf_arr = dvf_arr * dvf_weight
+    dvf_template = sitk.GetImageFromArray(dvf_arr)
+
+    # Copy image information
+    dvf_template.CopyInformation(mask_image)
+
+    if np.any(gaussian_smooth):
+
+        if not hasattr(gaussian_smooth, "__iter__"):
+            gaussian_smooth = (gaussian_smooth,) * 3
+
+        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth)
+
+    dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64))
+
+    mask_image_expanded = apply_transform(
+        mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor
+    )
+
+    return mask_image_expanded, dvf_tfm, dvf_template
+
+
+def contract_mask_away_from_target(
+    mask_image, target_image, contract_mag=20, gaussian_smooth=5
+):
+    """Generate a deformation vector field to contract a mask away from a target mask. Can be
+    useful for manipulating structures to generate failure cases for automated contour QA work.
+
+    Args:
+        mask_image (sitk.Image): The mask of the structure to manipulate
+        target_image (sitk.Image): The mask of the target structure to contract away from
+        contract_mag (int, optional): The magnitude of the contraction in mm. Defaults to 20.
+        gaussian_smooth (int, optional): Scale of a Gaussian kernel used to smooth the
+            deformation vector field. Defaults to 5.
+
+    Returns:
+        [SimpleITK.Image]: The binary mask following the contraction.
+        [SimpleITK.DisplacementFieldTransform]: The transform representing the contraction.
+        [SimpleITK.Image]: The displacement vector field representing the contraction.
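+
+    Example (assuming mask_image and target_image are binary SimpleITK.Image masks):
+        >>> mask_con, tfm, dvf = contract_mask_away_from_target(mask_image, target_image)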
+    """
+
+    # Remove any potential overlap between the target and the mask
+    target_image = sitk.MaskNegated(target_image, mask_image)
+    dvf_overlap_into_mask = contract_mag + 5
+
+    # Determine the unit vector pointing from the target towards the mask
+    # (the direction in which the mask is pushed away)
+    mask_com = get_com(mask_image, as_int=False, real_coords=True)
+    target_com = get_com(target_image, as_int=False, real_coords=True)
+
+    contract_vec = np.array([q - p for p, q in zip(target_com, mask_com)])
+    contract_vec = contract_vec / np.linalg.norm(contract_vec)
+
+    mask_image_arr = sitk.GetArrayFromImage(mask_image)
+
+    # Compute the distance map from the target to every other voxel
+    dist_map = sitk.SignedMaurerDistanceMap(target_image, squaredDistance=False)
+    dist_map_arr = sitk.GetArrayFromImage(dist_map)
+    dist_map_arr[dist_map_arr < 0] = 0
+
+    # Manipulate the distance map so that only voxels within the range of
+    # dvf_overlap_into_mask are kept
+    dist_from_mask_to_target = dist_map_arr[mask_image_arr > 0].min()
+    max_mask_dist = dist_map_arr[mask_image_arr > 0].max()
+    dist_map_arr[dist_map_arr > max_mask_dist] = max_mask_dist
+    dist_map_arr[dist_map_arr > dist_from_mask_to_target + dvf_overlap_into_mask] = (
+        dist_from_mask_to_target + dvf_overlap_into_mask
+    )
+
+    dvf_weight = np.zeros(dist_map_arr.shape)
+    dvf_weight[dist_map_arr < dist_from_mask_to_target + dvf_overlap_into_mask] = 1
+    dvf_weight = np.tile(np.expand_dims(dvf_weight, axis=3), [1, 1, 1, 3])
+
+    # The template deformation field
+    # Used to generate transforms
+    dvf_arr = np.zeros(mask_image_arr.shape + (3,))
+    dvf_arr = dvf_arr - np.array([[[contract_vec * contract_mag]]])
+
+    # Weight the deformation vectors by the manipulated distance map
+    dvf_arr = dvf_arr * dvf_weight
+    dvf_template = sitk.GetImageFromArray(dvf_arr)
+
+    # Copy image information
+    dvf_template.CopyInformation(mask_image)
+
+    if np.any(gaussian_smooth):
+
+        if not hasattr(gaussian_smooth, "__iter__"):
+            gaussian_smooth = (gaussian_smooth,) * 3
+
+        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth)
+
+    dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64))
+
+    mask_image_contracted = apply_transform(
+        mask_image, transform=dvf_tfm, default_value=0, interpolator=sitk.sitkNearestNeighbor
+    )
+
+    return mask_image_contracted, dvf_tfm, dvf_template
+
diff --git a/platipy/imaging/label/comparison.py b/platipy/imaging/label/comparison.py
index b59981fa..e154b999 100644
--- a/platipy/imaging/label/comparison.py
+++ b/platipy/imaging/label/comparison.py
@@ -74,8 +74,8 @@ def compute_surface_dsc(label_a, label_b, tau=3.0):
 
 def compute_surface_metrics(label_a, label_b, verbose=False):
     """Compute surface distance metrics between two labels.
Surface metrics computed are: - hausdorffDistance, meanSurfaceDistance, medianSurfaceDistance, maximumSurfaceDistance, - sigmaSurfaceDistance, surfaceDSC + hausdorffDistance, hausdorffDistance95, meanSurfaceDistance, medianSurfaceDistance, + maximumSurfaceDistance, sigmaSurfaceDistance, surfaceDSC Args: label_a (sitk.Image): A mask to compare @@ -95,8 +95,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False): std_sd_list = [] median_sd_list = [] num_points = [] - for (la, lb) in ((label_a, label_b), (label_b, label_a)): - + for la, lb in ((label_a, label_b), (label_b, label_a)): label_intensity_stat = sitk.LabelIntensityStatisticsImageFilter() reference_distance_map = sitk.Abs( sitk.SignedMaurerDistanceMap( @@ -118,6 +117,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False): mean_surf_dist = np.dot(mean_sd_list, num_points) / np.sum(num_points) max_surf_dist = np.max(max_sd_list) + hd_95 = np.percentile(max_sd_list, 95) std_surf_dist = np.sqrt( np.dot( num_points, @@ -131,6 +131,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False): result = {} result["hausdorffDistance"] = hd + result["hausdorffDistance95"] = hd_95 result["meanSurfaceDistance"] = mean_surf_dist result["medianSurfaceDistance"] = median_surf_dist result["maximumSurfaceDistance"] = max_surf_dist @@ -294,8 +295,7 @@ def compute_metric_masd(label_a, label_b, auto_crop=True): mean_sd_list = [] num_points = [] - for (la, lb) in ((label_a, label_b), (label_b, label_a)): - + for la, lb in ((label_a, label_b), (label_b, label_a)): label_intensity_stat = sitk.LabelIntensityStatisticsImageFilter() reference_distance_map = sitk.Abs( sitk.SignedMaurerDistanceMap( @@ -364,7 +364,6 @@ def compute_apl(label_ref, label_test, distance_threshold_mm=3): # iterate over each slice for i in range(n_slices): - if ( sitk.GetArrayViewFromImage(label_ref)[i].sum() + sitk.GetArrayViewFromImage(label_test)[i].sum() diff --git a/platipy/imaging/label/utils.py b/platipy/imaging/label/utils.py index f3e8138f..ae5ef957 100644 --- a/platipy/imaging/label/utils.py +++ b/platipy/imaging/label/utils.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy +from pathlib import Path + import SimpleITK as sitk import numpy as np @@ -284,3 +287,63 @@ def binary_decode_image(binary_encoded_img): continue return structure_list + + +def get_union_mask(masks): + """Get the union mask + + Args: + masks (list|dict): A list or dictionary of masks given as SimpleITK.Image or path to mask + file. + + Raises: + ValueError: Raised if masks provided is empty. + + Returns: + SimpleITK.Image: The union mask + """ + + if isinstance(masks, dict): + masks = [masks[k] for k in masks] + + if len(masks) == 0: + raise ValueError("Masks must not be empty") + + if isinstance(masks[0], (str, Path)): + masks = [sitk.ReadImage(str(m)) for m in masks] + + union_mask = copy.copy(masks[0]) + for mask in masks[1:]: + union_mask += mask + + return sitk.Cast(union_mask > 0, sitk.sitkUInt8) + + +def get_intersection_mask(masks): + """Get the intersection mask + + Args: + masks (list|dict): A list or dictionary of masks given as SimpleITK.Image or path to mask + file. + + Raises: + ValueError: Raised if masks provided is empty. 
+ + Returns: + SimpleITK.Image: The intersection mask + """ + + if isinstance(masks, dict): + masks = [masks[k] for k in masks] + + if len(masks) == 0: + raise ValueError("Masks must not be empty") + + if isinstance(masks[0], (str, Path)): + masks = [sitk.ReadImage(str(m)) for m in masks] + + intersection_mask = copy.copy(masks[0]) + for mask in masks[1:]: + intersection_mask += mask + + return sitk.Cast(intersection_mask == len(masks), sitk.sitkUInt8) diff --git a/platipy/imaging/tests/test_probunet.py b/platipy/imaging/tests/test_probunet.py new file mode 100644 index 00000000..e39b3996 --- /dev/null +++ b/platipy/imaging/tests/test_probunet.py @@ -0,0 +1,216 @@ +# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=redefined-outer-name,missing-function-docstring + +from argparse import ArgumentParser + +import pytest + +import pytorch_lightning as pl +from platipy.imaging.cnn.train import main, ProbUNet, UNetDataModule +from platipy.imaging.cnn.pseudo_generator import generate_pseudo_data + + +@pytest.fixture +def trainer_arg_parser(): + + generate_pseudo_data() + + arg_parser = ArgumentParser() + arg_parser = ProbUNet.add_model_specific_args(arg_parser) + arg_parser = UNetDataModule.add_model_specific_args(arg_parser) + arg_parser = pl.Trainer.add_argparse_args(arg_parser) + arg_parser.add_argument( + "--config", type=str, default=None, help="JSON file with parameters to load" + ) + arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment") + arg_parser.add_argument("--working_dir", type=str, default="./working") + arg_parser.add_argument("--num_observers", type=int, default=5) + arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1]) + arg_parser.add_argument("--offline", type=bool, default=False) + arg_parser.add_argument("--comet_api_key", type=str, default=None) + arg_parser.add_argument("--comet_workspace", type=str, default=None) + arg_parser.add_argument("--comet_project", type=str, default=None) + arg_parser.add_argument("--resume_from", type=str, default=None) + return arg_parser + + +def test_prob_unet_2d_elbo(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + "test_prob_unet_2d_elbo", + "--num_workers", + "1", + "--limit_train_batches", + "0.025", + "--loss_type", + "elbo", + "--prob_type", + "prob", + "--max_epochs", + "1", + "--ndims", + "2", + "--filters_per_layer", + "2", + "4", + ] + ) + + main(args) + + +def test_prob_unet_3d_elbo(trainer_arg_parser): + + args = trainer_arg_parser.parse_args( + [ + "--working_dir", + "test_prob_unet_3d_elbo", + "--num_workers", + "1", + "--limit_train_batches", + "0.1", + "--loss_type", + "elbo", + "--prob_type", + "prob", + "--max_epochs", + "1", + "--ndims", + "3", + "--filters_per_layer", + "2", + "4", + "--batch_size", + "1", + ] + ) + + main(args) + + 
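+# The tests below exercise the GECO loss (generalised ELBO with constrained
+# optimisation): kappa sets the reconstruction constraint target, and kappa_contour
+# (when supplied) adds a contour-based constraint.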
diff --git a/platipy/imaging/tests/test_probunet.py b/platipy/imaging/tests/test_probunet.py
new file mode 100644
index 00000000..e39b3996
--- /dev/null
+++ b/platipy/imaging/tests/test_probunet.py
@@ -0,0 +1,216 @@
+# Copyright 2021 University of New South Wales, University of Sydney, Ingham Institute
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=redefined-outer-name,missing-function-docstring
+
+from argparse import ArgumentParser
+
+import pytest
+
+import pytorch_lightning as pl
+from platipy.imaging.cnn.train import main, ProbUNet, UNetDataModule
+from platipy.imaging.cnn.pseudo_generator import generate_pseudo_data
+
+
+@pytest.fixture
+def trainer_arg_parser():
+
+    generate_pseudo_data()
+
+    arg_parser = ArgumentParser()
+    arg_parser = ProbUNet.add_model_specific_args(arg_parser)
+    arg_parser = UNetDataModule.add_model_specific_args(arg_parser)
+    arg_parser = pl.Trainer.add_argparse_args(arg_parser)
+    arg_parser.add_argument(
+        "--config", type=str, default=None, help="JSON file with parameters to load"
+    )
+    arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed")
+    arg_parser.add_argument("--experiment", type=str, default="default", help="Name of experiment")
+    arg_parser.add_argument("--working_dir", type=str, default="./working")
+    arg_parser.add_argument("--num_observers", type=int, default=5)
+    arg_parser.add_argument("--spacing", nargs="+", type=float, default=[1, 1, 1])
+    arg_parser.add_argument("--offline", type=bool, default=False)
+    arg_parser.add_argument("--comet_api_key", type=str, default=None)
+    arg_parser.add_argument("--comet_workspace", type=str, default=None)
+    arg_parser.add_argument("--comet_project", type=str, default=None)
+    arg_parser.add_argument("--resume_from", type=str, default=None)
+    return arg_parser
+
+
+def test_prob_unet_2d_elbo(trainer_arg_parser):
+
+    args = trainer_arg_parser.parse_args(
+        [
+            "--working_dir",
+            "test_prob_unet_2d_elbo",
+            "--num_workers",
+            "1",
+            "--limit_train_batches",
+            "0.025",
+            "--loss_type",
+            "elbo",
+            "--prob_type",
+            "prob",
+            "--max_epochs",
+            "1",
+            "--ndims",
+            "2",
+            "--filters_per_layer",
+            "2",
+            "4",
+        ]
+    )
+
+    main(args)
+
+
+def test_prob_unet_3d_elbo(trainer_arg_parser):
+
+    args = trainer_arg_parser.parse_args(
+        [
+            "--working_dir",
+            "test_prob_unet_3d_elbo",
+            "--num_workers",
+            "1",
+            "--limit_train_batches",
+            "0.1",
+            "--loss_type",
+            "elbo",
+            "--prob_type",
+            "prob",
+            "--max_epochs",
+            "1",
+            "--ndims",
+            "3",
+            "--filters_per_layer",
+            "2",
+            "4",
+            "--batch_size",
+            "1",
+        ]
+    )
+
+    main(args)
+
+
+def test_prob_unet_2d_geco(trainer_arg_parser):
+
+    args = trainer_arg_parser.parse_args(
+        [
+            "--working_dir",
+            "test_prob_unet_2d_geco",
+            "--num_workers",
+            "1",
+            "--limit_train_batches",
+            "0.025",
+            "--loss_type",
+            "geco",
+            "--prob_type",
+            "prob",
+            "--max_epochs",
+            "1",
+            "--ndims",
+            "2",
+            "--filters_per_layer",
+            "2",
+            "4",
+        ]
+    )
+
+    main(args)
+
+
+def test_prob_unet_3d_geco(trainer_arg_parser):
+
+    args = trainer_arg_parser.parse_args(
+        [
+            "--working_dir",
+            "test_prob_unet_3d_geco",
+            "--num_workers",
+            "1",
+            "--limit_train_batches",
+            "0.1",
+            "--loss_type",
+            "geco",
+            "--prob_type",
+            "prob",
+            "--max_epochs",
+            "1",
+            "--ndims",
+            "3",
+            "--filters_per_layer",
+            "2",
+            "4",
+            "--batch_size",
+            "1",
+        ]
+    )
+
+    main(args)
+
+
+def test_hierarchical_prob_unet_2d_geco(trainer_arg_parser):
+
+    args = trainer_arg_parser.parse_args(
+        [
+            "--working_dir",
+            "test_hierarchical_prob_unet_2d_geco",
+            "--num_workers",
+            "1",
+            "--limit_train_batches",
+            "0.025",
+            "--loss_type",
+            "geco",
+            "--prob_type",
+            "hierarchical",
+            "--max_epochs",
+            "1",
+            "--ndims",
+            "2",
+            "--filters_per_layer",
+            "2",
+            "4",
+        ]
+    )
+
+    main(args)
+
+
+def test_hierarchical_prob_unet_2d_geco_contour(trainer_arg_parser):
+
+    args = trainer_arg_parser.parse_args(
+        [
+            "--working_dir",
+            "test_hierarchical_prob_unet_2d_geco_contour",
+            "--num_workers",
+            "1",
+            "--limit_train_batches",
+            "0.025",
+            "--loss_type",
+            "geco",
+            "--prob_type",
+            "hierarchical",
+            "--max_epochs",
+            "1",
+            "--ndims",
+            "2",
+            "--filters_per_layer",
+            "2",
+            "4",
+            "--kappa_contour",
+            "0.01",
+        ]
+    )
+
+    main(args)
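The geco tests above swap the fixed-weight ELBO objective for GECO-style constrained optimisation, selected via --loss_type geco. The objective itself lives in platipy.imaging.cnn.train, which this diff does not include; purely for orientation, here is a minimal sketch of a GECO-style multiplier update in the spirit of Rezende and Viola's "Taming VAEs". All names and constants below are hypothetical and are not platipy's API:

    import torch

    class GECOLoss:
        # Minimise the KL term subject to reconstruction <= kappa, adapting a
        # Lagrange multiplier lambda = exp(log_lamb) as training progresses.
        def __init__(self, kappa=0.05, alpha=0.99, nu=0.01):
            self.kappa = kappa  # reconstruction constraint target
            self.alpha = alpha  # EMA smoothing of the constraint
            self.nu = nu        # multiplier update rate
            self.log_lamb = torch.zeros(1)
            self.constraint_ema = None

        def __call__(self, reconstruction, kl):
            constraint = reconstruction - self.kappa
            detached = constraint.detach()
            if self.constraint_ema is None:
                self.constraint_ema = detached
            else:
                self.constraint_ema = self.alpha * self.constraint_ema + (1 - self.alpha) * detached
            # lambda <- lambda * exp(nu * EMA(constraint))
            self.log_lamb = self.log_lamb + self.nu * self.constraint_ema
            return kl + self.log_lamb.exp() * constraint

The --kappa_contour flag exercised by the last test appears to add a second constraint target for a contour-based term; its exact form is not visible in this diff.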
diff --git a/platipy/imaging/utils/crop.py b/platipy/imaging/utils/crop.py
index 148b030e..31e4d1b7 100644
--- a/platipy/imaging/utils/crop.py
+++ b/platipy/imaging/utils/crop.py
@@ -45,6 +45,13 @@ def label_to_roi(label, expansion_mm=[0, 0, 0], return_as_list=False):
     label_stats_image_filter.Execute(reference_label, reference_label)
     bounding_box = np.array(label_stats_image_filter.GetBoundingBox(1))
 
+    # If the bounding box is empty the mask is likely empty; return the entire image extent as the ROI.
+    if bounding_box.size == 0:
+        if return_as_list:
+            return [0, 0, 0] + [int(x) for x in label.GetSize()]
+
+        return [int(x) for x in label.GetSize()], [0, 0, 0]
+
     index = [bounding_box[x * 2] for x in range(3)]
     size = [bounding_box[(x * 2) + 1] - bounding_box[x * 2] + 1 for x in range(3)]
diff --git a/requirements-dl.txt b/requirements-dl.txt
new file mode 100644
index 00000000..3e711455
--- /dev/null
+++ b/requirements-dl.txt
@@ -0,0 +1,4 @@
+pytorch-lightning >= 1.3.7
+torch >= 1.9.0
+comet-ml >= 3.12.2
+imgaug >= 0.4.0
\ No newline at end of file
diff --git a/services/nnunet/service.py b/services/nnunet/service.py
index af101cff..5cbf0222 100644
--- a/services/nnunet/service.py
+++ b/services/nnunet/service.py
@@ -14,13 +14,14 @@
 import os
 import subprocess
+import json
 from pathlib import Path
 import logging
 
 import SimpleITK as sitk
 
-from platipy.backend import app, DataObject, celery # pylint: disable=unused-import
+from platipy.backend import app, DataObject, celery  # pylint: disable=unused-import
 
 logger = logging.getLogger(__name__)
@@ -32,12 +33,13 @@
     "clean_sup_slices": False,
 }
 
+
 def clean_sup_slices(mask):
     lssif = sitk.LabelShapeStatisticsImageFilter()
     max_slice_size = 0
     sizes = {}
-    for z in range(mask.GetSize()[2]-1, -1, -1):
-        lssif.Execute(sitk.ConnectedComponent(mask[:,:,z]))
+    for z in range(mask.GetSize()[2] - 1, -1, -1):
+        lssif.Execute(sitk.ConnectedComponent(mask[:, :, z]))
         if len(lssif.GetLabels()) == 0:
             continue
@@ -48,13 +50,40 @@ def clean_sup_slices(mask):
             sizes[z] = phys_size
 
     for z in sizes:
-        if sizes[z] > max_slice_size/2:
-            mask[:,:,z+1:mask.GetSize()[2]] = 0
+        if sizes[z] > max_slice_size / 2:
+            mask[:, :, z + 1 : mask.GetSize()[2]] = 0
             break
 
     return mask
 
 
+def get_structure_names(task):
+    # Look up structure names if we can find them in the task's dataset.json file
+    if "nnUNet_raw_data_base" not in os.environ:
+        logger.info("nnUNet_raw_data_base not set")
+        return {}
+
+    raw_path = Path(os.environ["nnUNet_raw_data_base"])
+    task_path = raw_path.joinpath("nnUNet_raw_data", task)
+    dataset_file = task_path.joinpath("dataset.json")
+
+    logger.info("Attempting to read %s", dataset_file)
+
+    if not dataset_file.exists():
+        logger.info("dataset.json file does not exist at %s", dataset_file)
+        return {}
+
+    dataset = {}
+    with open(dataset_file, "r") as f:
+        dataset = json.load(f)
+
+    if "labels" not in dataset:
+        logger.info("No labels found in dataset.json file")
+        return {}
+
+    return dataset["labels"]
+
+
 @app.register("nnUNet Service", default_settings=NNUNET_SETTINGS_DEFAULTS)
 def nnunet_service(data_objects, working_dir, settings):
     """
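get_structure_names reads the labels mapping that nnU-Net stores in each task's dataset.json, where the keys are label ids as strings. A usage sketch under assumed paths (the environment value, task name, and label names here are illustrative only):

    import os

    # Assume $nnUNet_raw_data_base/nnUNet_raw_data/Task200_Heart/dataset.json
    # contains: {"labels": {"0": "background", "1": "Heart"}}
    os.environ["nnUNet_raw_data_base"] = "/data/nnunet/raw"  # assumed location

    labels = get_structure_names("Task200_Heart")  # hypothetical task name
    # labels == {"0": "background", "1": "Heart"}
    # An empty dict means the lookup failed; the service then falls back to
    # generic names of the form "Structure_<label_id>".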
@@ -73,8 +102,10 @@ def nnunet_service(data_objects, working_dir, settings):
     output_path = Path(working_dir).joinpath("output")
     output_path.mkdir()
 
-    for data_object in data_objects:
+    labels = get_structure_names(settings["task"])
+    logger.info("Read labels: %s", labels)
 
+    for data_object in data_objects:
         # Create a symbolic link for each image to auto-segment using the nnUNet
         do_path = Path(data_object.path)
         io_path = input_path.joinpath(f"{settings['task']}_0000.nii.gz")
@@ -109,14 +140,29 @@ def nnunet_service(data_objects, working_dir, settings):
         subprocess.call(command)
 
         for op in output_path.glob("*.nii.gz"):
+            label_map = sitk.ReadImage(str(op))
+
+            label_map_arr = sitk.GetArrayFromImage(label_map)
+            label_count = label_map_arr.max()
+
+            for label_id in range(1, label_count + 1):
+                mask = label_map == label_id
 
-            if settings["clean_sup_slices"]:
-                mask = sitk.ReadImage(str(op))
-                mask = clean_sup_slices(mask)
-                sitk.WriteImage(mask, str(op))
+                label_name = f"Structure_{label_id}"
+                if str(label_id) in labels:
+                    label_name = labels[str(label_id)]
 
-            output_data_object = DataObject(type="FILE", path=str(op), parent=data_object)
-            output_objects.append(output_data_object)
+                if settings["clean_sup_slices"]:
+                    mask = clean_sup_slices(mask)
+
+                mask_file = output_path.joinpath(f"{label_name}.nii.gz")
+
+                sitk.WriteImage(mask, str(mask_file))
+
+                output_data_object = DataObject(
+                    type="FILE", path=str(mask_file), parent=data_object
+                )
+                output_objects.append(output_data_object)
 
         os.remove(io_path)
@@ -126,7 +172,6 @@
 
 
 if __name__ == "__main__":
-    # Run app by calling "python service.py" from the command line
 
     DICOM_LISTENER_PORT = 7777
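The reworked output loop above splits nnU-Net's single label map into one binary mask per structure before registering each as an output object. The key SimpleITK idiom is that comparing an image against a scalar yields a UInt8 binary image. A standalone sketch of the same split, with hypothetical file names:

    import SimpleITK as sitk

    label_map = sitk.ReadImage("output/predicted.nii.gz")  # hypothetical path
    arr = sitk.GetArrayViewFromImage(label_map)

    for label_id in range(1, int(arr.max()) + 1):
        mask = label_map == label_id  # 1 where the voxel belongs to this structure
        sitk.WriteImage(mask, f"output/Structure_{label_id}.nii.gz")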