diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 30c5be53..fe57d357 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -69,6 +69,9 @@ def __init__( self.volume_shape = volume_shape self.n_classes = n_classes + if self.n_classes < 1: + raise ValueError("n_classes must be > 0.") + @classmethod def from_tfrecords( cls, @@ -80,6 +83,7 @@ def from_tfrecords( n_classes=1, tf_dataset_options=None, num_parallel_calls=1, + label_mapping=None, ): """Function to retrieve a saved tf record as a nobrainer Dataset @@ -123,6 +127,7 @@ def from_tfrecords( if not n_volumes: n_volumes = block_length * len(files) + print(f"n_volumes: {n_volumes}") dataset = dataset.interleave( map_func=lambda x: tf.data.TFRecordDataset( @@ -138,7 +143,11 @@ def from_tfrecords( if block_shape: ds_obj.block(block_shape) if not scalar_labels: - ds_obj.map_labels() + ds_obj.map_labels( + label_mapping=label_mapping, num_parallel_calls=num_parallel_calls + ) + + # ds_obj.filter_zero_volumes() # TODO automatically determine batch size ds_obj.batch(1) @@ -158,6 +167,7 @@ def from_files( n_classes=1, block_shape=None, tf_dataset_options=None, + label_mapping=None, ): """Create Nobrainer datasets from data filepaths: List(str), list of paths to individual input data files. @@ -221,6 +231,7 @@ def from_files( block_shape=block_shape, tf_dataset_options=tf_dataset_options, num_parallel_calls=num_parallel_calls, + label_mapping=label_mapping, ) ds_eval = None if n_eval > 0: @@ -234,6 +245,7 @@ def from_files( block_shape=block_shape, tf_dataset_options=tf_dataset_options, num_parallel_calls=num_parallel_calls, + label_mapping=label_mapping, ) return ds_train, ds_eval @@ -315,19 +327,31 @@ def _f(x, y): self.dataset = self.dataset.unbatch() return self - def map_labels(self, label_mapping=None): - if self.n_classes < 1: - raise ValueError("n_classes must be > 0.") - + def map_labels(self, label_mapping=None, num_parallel_calls=1): if label_mapping is not None: - self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping))) + self.map(lambda x, y: (x, replace(y, mapping=label_mapping))) if self.n_classes == 1: - self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) + self.map( + lambda x, y: (x, tf.expand_dims(binarize(y), -1)), + num_parallel_calls=num_parallel_calls, + ) elif self.n_classes == 2: - self.map(lambda x, y: (x, tf.one_hot(binarize(y), self.n_classes))) + self.map( + lambda x, y: ( + x, + tf.one_hot(tf.cast(binarize(y), dtype=tf.int32), self.n_classes), + ), + num_parallel_calls=num_parallel_calls, + ) elif self.n_classes > 2: - self.map(lambda x, y: (x, tf.one_hot(y, self.n_classes))) + self.map( + lambda x, y: ( + x, + tf.one_hot(tf.cast(y, dtype=tf.int32), self.n_classes), + ), + num_parallel_calls=num_parallel_calls, + ) return self @@ -363,3 +387,9 @@ def repeat(self, n_repeats): # through once. self.dataset = self.dataset.repeat(n_repeats) return self + + def filter_zero_volumes(self): + self.dataset = self.dataset.filter( + lambda x, y: tf.cast(tf.math.reduce_sum(y), dtype="bool") + ) + return self diff --git a/nobrainer/ext/SynthSeg/__init__.py b/nobrainer/ext/SynthSeg/__init__.py new file mode 100644 index 00000000..151b2e92 --- /dev/null +++ b/nobrainer/ext/SynthSeg/__init__.py @@ -0,0 +1 @@ +from . import model_inputs diff --git a/nobrainer/ext/SynthSeg/model_inputs.py b/nobrainer/ext/SynthSeg/model_inputs.py new file mode 100644 index 00000000..e6d9cda2 --- /dev/null +++ b/nobrainer/ext/SynthSeg/model_inputs.py @@ -0,0 +1,198 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +# python imports +import numpy as np +import numpy.random as npr + +# third-party imports +from nobrainer.ext.lab2im import utils + + +def build_model_inputs( + path_label_maps, + n_labels, + batchsize=1, + n_channels=1, + subjects_prob=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, +): + """ + This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the + input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch). + :param path_label_maps: list of the paths of the input label maps. + :param n_labels: number of labels in the input label maps. + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthesised. Default is 1. + :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick + the provided label maps at each minibatch. Must be a 1D numpy array, as long as path_label_maps. + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a + 1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K + is the total number of classes. Default is all labels have different classes. + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch + from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default + values for half of these cases, and thus generate images of random contrast. + """ + + # allocate unique class to each label if generation classes is not given + if generation_classes is None: + generation_classes = np.arange(n_labels) + n_classes = len(np.unique(generation_classes)) + + # make sure subjects_prob sums to 1 + subjects_prob = utils.load_array_if_path(subjects_prob) + if subjects_prob is not None: + subjects_prob /= np.sum(subjects_prob) + + # Generate! + while True: + + # randomly pick as many images as batchsize + indices = npr.choice( + np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob + ) + + # initialise input lists + list_label_maps = [] + list_means = [] + list_stds = [] + + for idx in indices: + + # load input label map + lab = utils.load_volume( + path_label_maps[idx], dtype="int", aff_ref=np.eye(4) + ) + if (npr.uniform() > 0.7) & ("seg_cerebral" in path_label_maps[idx]): + lab[lab == 24] = 0 + + # add label map to inputs + list_label_maps.append(utils.add_axis(lab, axis=[0, -1])) + + # add means and standard deviations to inputs + means = np.empty((1, n_labels, 0)) + stds = np.empty((1, n_labels, 0)) + for channel in range(n_channels): + + # retrieve channel specific stats if necessary + if isinstance(prior_means, np.ndarray): + if (prior_means.shape[0] > 2) & use_specific_stats_for_channel: + if prior_means.shape[0] / 2 != n_channels: + raise ValueError( + "the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_means = prior_means[2 * channel : 2 * channel + 2, :] + else: + tmp_prior_means = prior_means + else: + tmp_prior_means = prior_means + if ( + (prior_means is not None) + & mix_prior_and_random + & (npr.uniform() > 0.5) + ): + tmp_prior_means = None + if isinstance(prior_stds, np.ndarray): + if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel: + if prior_stds.shape[0] / 2 != n_channels: + raise ValueError( + "the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_stds = prior_stds[2 * channel : 2 * channel + 2, :] + else: + tmp_prior_stds = prior_stds + else: + tmp_prior_stds = prior_stds + if ( + (prior_stds is not None) + & mix_prior_and_random + & (npr.uniform() > 0.5) + ): + tmp_prior_stds = None + + # draw means and std devs from priors + tmp_classes_means = utils.draw_value_from_distribution( + tmp_prior_means, + n_classes, + prior_distributions, + 125.0, + 125.0, + positive_only=True, + ) + tmp_classes_stds = utils.draw_value_from_distribution( + tmp_prior_stds, + n_classes, + prior_distributions, + 15.0, + 15.0, + positive_only=True, + ) + random_coef = npr.uniform() + if random_coef > 0.95: # reset the background to 0 in 5% of cases + tmp_classes_means[0] = 0 + tmp_classes_stds[0] = 0 + elif ( + random_coef > 0.7 + ): # reset the background to low Gaussian in 25% of cases + tmp_classes_means[0] = npr.uniform(0, 15) + tmp_classes_stds[0] = npr.uniform(0, 5) + tmp_means = utils.add_axis( + tmp_classes_means[generation_classes], axis=[0, -1] + ) + tmp_stds = utils.add_axis( + tmp_classes_stds[generation_classes], axis=[0, -1] + ) + means = np.concatenate([means, tmp_means], axis=-1) + stds = np.concatenate([stds, tmp_stds], axis=-1) + list_means.append(means) + list_stds.append(stds) + + # build list of inputs for generation model + list_inputs = [list_label_maps, list_means, list_stds] + if batchsize > 1: # concatenate each input type if batchsize > 1 + list_inputs = [np.concatenate(item, 0) for item in list_inputs] + else: + list_inputs = [item[0] for item in list_inputs] + + yield list_inputs diff --git a/nobrainer/ext/__init__.py b/nobrainer/ext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nobrainer/ext/lab2im/__init__.py b/nobrainer/ext/lab2im/__init__.py new file mode 100644 index 00000000..f26d7db9 --- /dev/null +++ b/nobrainer/ext/lab2im/__init__.py @@ -0,0 +1 @@ +from . import edit_tensors, edit_volumes, image_generator, lab2im_model, layers, utils diff --git a/nobrainer/ext/lab2im/edit_tensors.py b/nobrainer/ext/lab2im/edit_tensors.py new file mode 100644 index 00000000..c34d0374 --- /dev/null +++ b/nobrainer/ext/lab2im/edit_tensors.py @@ -0,0 +1,437 @@ +""" + +This file contains functions to handle keras/tensorflow tensors. + - blurring_sigma_for_downsampling + - gaussian_kernel + - resample_tensor + - expand_dims + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. + +""" + +from itertools import combinations + +import keras.backend as K +import keras.layers as KL + +# python imports +import numpy as np +import tensorflow as tf + +# project imports +from nobrainer.ext.lab2im import utils + +# third-party imports +import nobrainer.ext.neuron.layers as nrn_layers +from nobrainer.ext.neuron.utils import volshape_to_meshgrid + + +def blurring_sigma_for_downsampling( + current_res, downsample_res, mult_coef=None, thickness=None +): + """Compute standard deviations of 1d gaussian masks for image blurring before downsampling. + :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor. + :param current_res: resolution of the volume before downsampling. + Can be a 1d numpy array or list or tensor of the same length as downsample res. + :param mult_coef: (optional) multiplicative coefficient for the blurring kernel. Default is 0.75. + :param thickness: (optional) slice thickness in each dimension. Must be the same type as downsample_res. + :return: standard deviation of the blurring masks given as the same type as downsample_res (list or tensor). + """ + + if not tf.is_tensor(downsample_res): + + # get blurring resolution (min between downsample_res and thickness) + current_res = np.array(current_res) + downsample_res = np.array(downsample_res) + if thickness is not None: + downsample_res = np.minimum(downsample_res, np.array(thickness)) + + # get std deviation for blurring kernels + if mult_coef is None: + sigma = 0.75 * downsample_res / current_res + sigma[downsample_res == current_res] = 0.5 + else: + sigma = mult_coef * downsample_res / current_res + sigma[downsample_res == 0] = 0 + + else: + + # reformat data resolution at which we blur + if thickness is not None: + down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))( + [downsample_res, thickness] + ) + else: + down_res = downsample_res + + # get std deviation for blurring kernels + if mult_coef is None: + sigma = KL.Lambda( + lambda x: tf.where( + tf.math.equal( + x, tf.convert_to_tensor(current_res, dtype="float32") + ), + 0.5, + 0.75 * x / tf.convert_to_tensor(current_res, dtype="float32"), + ) + )(down_res) + else: + sigma = KL.Lambda( + lambda x: mult_coef + * x + / tf.convert_to_tensor(current_res, dtype="float32") + )(down_res) + sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.0), 0.0, x[1]))( + [down_res, sigma] + ) + + return sigma + + +def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): + """Build gaussian kernels of the specified standard deviation. The outputs are given as tensorflow tensors. + :param sigma: standard deviation of the tensors. Can be given as a list/numpy array or as tensors. In each case, + sigma must have the same length as the number of dimensions of the volume that will be blurred with the output + tensors (e.g. sigma must have 3 values for 3D volumes). + :param max_sigma: + :param blur_range: + :param separable: + :return: + """ + # convert sigma into a tensor + if not tf.is_tensor(sigma): + sigma_tens = tf.convert_to_tensor( + utils.reformat_to_list(sigma), dtype="float32" + ) + else: + assert ( + max_sigma is not None + ), "max_sigma must be provided when sigma is given as a tensor" + sigma_tens = sigma + shape = sigma_tens.get_shape().as_list() + + # get n_dims and batchsize + if shape[0] is not None: + n_dims = shape[0] + batchsize = None + else: + n_dims = shape[1] + batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0] + + # reformat max_sigma + if max_sigma is not None: # dynamic blurring + max_sigma = np.array(utils.reformat_to_list(max_sigma, length=n_dims)) + else: # sigma is fixed + max_sigma = np.array(utils.reformat_to_list(sigma, length=n_dims)) + + # randomise the burring std dev and/or split it between dimensions + if blur_range is not None: + if blur_range != 1: + sigma_tens = sigma_tens * tf.random.uniform( + tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range + ) + + # get size of blurring kernels + windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1 + + if separable: + + split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1) + + kernels = list() + comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) + for i, wsize in enumerate(windowsize): + + if wsize > 1: + + # build meshgrid and replicate it along batch dim if dynamic blurring + locations = tf.cast(tf.range(0, wsize), "float32") - (wsize - 1) / 2 + if batchsize is not None: + locations = tf.tile( + tf.expand_dims(locations, axis=0), + tf.concat( + [ + batchsize, + tf.ones(tf.shape(tf.shape(locations)), dtype="int32"), + ], + axis=0, + ), + ) + comb[i] += 1 + + # compute gaussians + exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2) + g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i])) + g = g / tf.reduce_sum(g) + + for axis in comb[i]: + g = tf.expand_dims(g, axis=axis) + kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1)) + + else: + kernels.append(None) + + else: + + # build meshgrid + mesh = [ + tf.cast(f, "float32") + for f in volshape_to_meshgrid(windowsize, indexing="ij") + ] + diff = tf.stack( + [mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1 + ) + + # replicate meshgrid to batch size and reshape sigma_tens + if batchsize is not None: + diff = tf.tile( + tf.expand_dims(diff, axis=0), + tf.concat( + [batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype="int32")], + axis=0, + ), + ) + for i in range(n_dims): + sigma_tens = tf.expand_dims(sigma_tens, axis=1) + else: + for i in range(n_dims): + sigma_tens = tf.expand_dims(sigma_tens, axis=0) + + # compute gaussians + sigma_is_0 = tf.equal(sigma_tens, 0) + exp_term = -K.square(diff) / ( + 2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens) ** 2 + ) + norms = exp_term - tf.math.log( + tf.where( + sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens + ) + ) + kernels = K.sum(norms, -1) + kernels = tf.exp(kernels) + kernels /= tf.reduce_sum(kernels) + kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1) + + return kernels + + +def sobel_kernels(n_dims): + """Returns sobel kernels to compute spatial derivative on image of n dimensions.""" + + in_dir = tf.convert_to_tensor([1, 0, -1], dtype="float32") + orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype="float32") + comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) + + list_kernels = list() + for dim in range(n_dims): + + sublist_kernels = list() + for axis in range(n_dims): + + kernel = in_dir if axis == dim else orthogonal_dir + for i in comb[axis]: + kernel = tf.expand_dims(kernel, axis=i) + sublist_kernels.append(tf.expand_dims(tf.expand_dims(kernel, -1), -1)) + + list_kernels.append(sublist_kernels) + + return list_kernels + + +def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None): + """Build kernel with values of 1 for voxel at a distance < dist_threshold from the center, and 0 otherwise. + The outputs are given as tensorflow tensors. + :param dist_threshold: maximum distance from the center until voxel will have a value of 1. Can be a tensor of size + (batch_size, 1), or a float. + :param n_dims: dimension of the kernel to return (excluding batch and channel dimensions). + :param max_dist_threshold: if distance_threshold is a tensor, max_dist_threshold must be given. It represents the + maximum value that will be passed to dist_threshold. Must be a float. + """ + + # convert dist_threshold into a tensor + if not tf.is_tensor(dist_threshold): + dist_threshold_tens = tf.convert_to_tensor( + utils.reformat_to_list(dist_threshold), dtype="float32" + ) + else: + assert ( + max_dist_threshold is not None + ), "max_sigma must be provided when dist_threshold is given as a tensor" + dist_threshold_tens = tf.cast(dist_threshold, "float32") + shape = dist_threshold_tens.get_shape().as_list() + + # get batchsize + batchsize = ( + None + if shape[0] is not None + else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0] + ) + + # set max_dist_threshold into an array + if ( + max_dist_threshold is None + ): # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch) + max_dist_threshold = dist_threshold + + # get size of blurring kernels + windowsize = np.array([max_dist_threshold * 2 + 1] * n_dims, dtype="int32") + + # build tensor representing the distance from the centre + mesh = [ + tf.cast(f, "float32") for f in volshape_to_meshgrid(windowsize, indexing="ij") + ] + dist = tf.stack( + [mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1 + ) + dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1)) + + # replicate distance to batch size and reshape sigma_tens + if batchsize is not None: + dist = tf.tile( + tf.expand_dims(dist, axis=0), + tf.concat( + [batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype="int32")], axis=0 + ), + ) + for i in range(n_dims - 1): + dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1) + else: + for i in range(n_dims - 1): + dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0) + + # build final kernel by thresholding distance tensor + kernel = tf.where( + tf.less_equal(dist, dist_threshold_tens), + tf.ones_like(dist), + tf.zeros_like(dist), + ) + kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) + + return kernel + + +def resample_tensor( + tensor, + resample_shape, + interp_method="linear", + subsample_res=None, + volume_res=None, + build_reliability_map=False, +): + """This function resamples a volume to resample_shape. It does not apply any pre-filtering. + A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be + specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which + slices were interpolated during resampling from the downsampled to final tensor. + :param tensor: tensor + :param resample_shape: list or numpy array of size (n_dims,) + :param interp_method: (optional) interpolation method for resampling, 'linear' (default) or 'nearest' + :param subsample_res: (optional) if not None, this triggers a downsampling of the volume, prior to the resampling + step. List or numpy array of size (n_dims,). Default si None. + :param volume_res: (optional) if subsample_res is not None, this should be provided to compute downsampling ratio. + list or numpy array of size (n_dims,). Default is None. + :param build_reliability_map: whether to return reliability map along with the resampled tensor. This map indicates + which slices of the resampled tensor are interpolated (0=interpolated, 1=real slice, in between=degree of realness). + :return: resampled volume, with reliability map if necessary. + """ + + # reformat resolutions to lists + subsample_res = utils.reformat_to_list(subsample_res) + volume_res = utils.reformat_to_list(volume_res) + n_dims = len(resample_shape) + + # downsample image + tensor_shape = tensor.get_shape().as_list()[1:-1] + downsample_shape = tensor_shape # will be modified if we actually downsample + + if subsample_res is not None: + assert ( + volume_res is not None + ), "volume_res must be given when providing a subsampling resolution." + assert len(subsample_res) == len(volume_res), ( + "subsample_res and volume_res must have the same length, " + "had {0}, and {1}".format(len(subsample_res), len(volume_res)) + ) + if subsample_res != volume_res: + + # get shape at which we downsample + downsample_shape = [ + int(tensor_shape[i] * volume_res[i] / subsample_res[i]) + for i in range(n_dims) + ] + + # downsample volume + tensor._keras_shape = tuple(tensor.get_shape().as_list()) + tensor = nrn_layers.Resize(size=downsample_shape, interp_method="nearest")( + tensor + ) + + # resample image at target resolution + if ( + resample_shape != downsample_shape + ): # if we didn't downsample downsample_shape = tensor_shape + tensor._keras_shape = tuple(tensor.get_shape().as_list()) + tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)( + tensor + ) + + # compute reliability maps if necessary and return results + if build_reliability_map: + + # compute maps only if we downsampled + if downsample_shape != tensor_shape: + + # compute upsampling factors + upsampling_factors = np.array(resample_shape) / np.array(downsample_shape) + + # build reliability map + reliability_map = 1 + for i in range(n_dims): + loc_float = np.arange(0, resample_shape[i], upsampling_factors[i]) + loc_floor = np.int32(np.floor(loc_float)) + loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1)) + tmp_reliability_map = np.zeros(resample_shape[i]) + tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor) + tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + ( + loc_float - loc_floor + ) + shape = [1, 1, 1] + shape[i] = resample_shape[i] + reliability_map = reliability_map * np.reshape( + tmp_reliability_map, shape + ) + shape = KL.Lambda(lambda x: tf.shape(x))(tensor) + mask = KL.Lambda( + lambda x: tf.reshape( + tf.convert_to_tensor(reliability_map, dtype="float32"), shape=x + ) + )(shape) + + # otherwise just return an all-one tensor + else: + mask = KL.Lambda(lambda x: tf.ones_like(x))(tensor) + + return tensor, mask + + else: + return tensor + + +def expand_dims(tensor, axis=0): + """Expand the dimensions of the input tensor along the provided axes (given as an integer or a list).""" + axis = utils.reformat_to_list(axis) + for ax in axis: + tensor = tf.expand_dims(tensor, axis=ax) + return tensor diff --git a/nobrainer/ext/lab2im/edit_volumes.py b/nobrainer/ext/lab2im/edit_volumes.py new file mode 100644 index 00000000..9bed9453 --- /dev/null +++ b/nobrainer/ext/lab2im/edit_volumes.py @@ -0,0 +1,3591 @@ +""" +This file contains functions to edit/preprocess volumes (i.e. not tensors!). +These functions are sorted in five categories: +1- volume editing: this can be applied to any volume (i.e. images or label maps). It contains: + -mask_volume + -rescale_volume + -crop_volume + -crop_volume_around_region + -crop_volume_with_idx + -pad_volume + -flip_volume + -resample_volume + -resample_volume_like + -get_ras_axes + -align_volume_to_ref + -blur_volume +2- label map editing: can be applied to label maps only. It contains: + -correct_label_map + -mask_label_map + -smooth_label_map + -erode_label_map + -get_largest_connected_component + -compute_hard_volumes + -compute_distance_map +3- editing all volumes in a folder: functions are more or less the same as 1, but they now apply to all the volumes +in a given folder. Thus we provide folder paths rather than numpy arrays as inputs. It contains: + -mask_images_in_dir + -rescale_images_in_dir + -crop_images_in_dir + -crop_images_around_region_in_dir + -pad_images_in_dir + -flip_images_in_dir + -align_images_in_dir + -correct_nans_images_in_dir + -blur_images_in_dir + -create_mutlimodal_images + -convert_images_in_dir_to_nifty + -mri_convert_images_in_dir + -samseg_images_in_dir + -niftyreg_images_in_dir + -upsample_anisotropic_images + -simulate_upsampled_anisotropic_images + -check_images_in_dir +4- label maps in dir: same as 3 but for label map-specific functions. It contains: + -correct_labels_in_dir + -mask_labels_in_dir + -smooth_labels_in_dir + -erode_labels_in_dir + -upsample_labels_in_dir + -compute_hard_volumes_in_dir + -build_atlas +5- dataset editing: functions for editing datasets (i.e. images with corresponding label maps). It contains: + -check_images_and_labels + -crop_dataset_to_minimum_size + -subdivide_dataset_to_patches + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +import csv + +# python imports +import os +import shutil + +import keras.layers as KL +from keras.models import Model +import numpy as np +from scipy.interpolate import RegularGridInterpolator +from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter +from scipy.ndimage import label as scipy_label +from scipy.ndimage.filters import convolve +from scipy.ndimage.morphology import binary_fill_holes, distance_transform_edt +import tensorflow as tf + +# project imports +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im.edit_tensors import blurring_sigma_for_downsampling +from nobrainer.ext.lab2im.layers import ConvertLabels, GaussianBlur + +# ---------------------------------------------------- edit volume ----------------------------------------------------- + + +def mask_volume( + volume, + mask=None, + threshold=0.1, + dilate=0, + erode=0, + fill_holes=False, + masking_value=0, + return_mask=False, + return_copy=True, +): + """Mask a volume, either with a given mask, or by keeping only the values above a threshold. + :param volume: a numpy array, possibly with several channels + :param mask: (optional) a numpy array to mask volume with. + Mask doesn't have to be a 0/1 array, all strictly positive values of mask are considered for masking volume. + Mask should have the same size as volume. If volume has several channels, mask can either be uni- or multi-channel. + In the first case, the same mask is applied to all channels. + :param threshold: (optional) If mask is None, masking is performed by keeping thresholding the input. + :param dilate: (optional) number of voxels by which to dilate the provided or computed mask. + :param erode: (optional) number of voxels by which to erode the provided or computed mask. + :param fill_holes: (optional) whether to fill the holes in the provided or computed mask. + :param masking_value: (optional) masking value + :param return_mask: (optional) whether to return the applied mask + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: the masked volume, and the applied mask if return_mask is True. + """ + + # get info + new_volume = volume.copy() if return_copy else volume + vol_shape = list(new_volume.shape) + n_dims, n_channels = utils.get_dims(vol_shape) + + # get mask and erode/dilate it + if mask is None: + mask = new_volume >= threshold + else: + assert ( + list(mask.shape[:n_dims]) == vol_shape[:n_dims] + ), "mask should have shape {0}, or {1}, had {2}".format( + vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape) + ) + mask = mask > 0 + if dilate > 0: + dilate_struct = utils.build_binary_structure(dilate, n_dims) + mask_to_apply = binary_dilation(mask, dilate_struct) + else: + mask_to_apply = mask + if erode > 0: + erode_struct = utils.build_binary_structure(erode, n_dims) + mask_to_apply = binary_erosion(mask_to_apply, erode_struct) + if fill_holes: + mask_to_apply = binary_fill_holes(mask_to_apply) + + # replace values outside of mask by padding_char + if mask_to_apply.shape == new_volume.shape: + new_volume[np.logical_not(mask_to_apply)] = masking_value + else: + new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = ( + masking_value + ) + + if return_mask: + return new_volume, mask_to_apply + else: + return new_volume + + +def rescale_volume( + volume, + new_min=0, + new_max=255, + min_percentile=2, + max_percentile=98, + use_positive_only=False, +): + """This function linearly rescales a volume between new_min and new_max. + :param volume: a numpy array + :param new_min: (optional) minimum value for the rescaled image. + :param new_max: (optional) maximum value for the rescaled image. + :param min_percentile: (optional) percentile for estimating robust minimum of volume (float in [0,...100]), + where 0 = np.min + :param max_percentile: (optional) percentile for estimating robust maximum of volume (float in [0,...100]), + where 100 = np.max + :param use_positive_only: (optional) whether to use only positive values when estimating the min and max percentile + :return: rescaled volume + """ + + # select only positive intensities + new_volume = volume.copy() + intensities = ( + new_volume[new_volume > 0] if use_positive_only else new_volume.flatten() + ) + + # define min and max intensities in original image for normalisation + robust_min = ( + np.min(intensities) + if min_percentile == 0 + else np.percentile(intensities, min_percentile) + ) + robust_max = ( + np.max(intensities) + if max_percentile == 100 + else np.percentile(intensities, max_percentile) + ) + + # trim values outside range + new_volume = np.clip(new_volume, robust_min, robust_max) + + # rescale image + if robust_min != robust_max: + return new_min + (new_volume - robust_min) / (robust_max - robust_min) * ( + new_max - new_min + ) + else: # avoid dividing by zero + return np.zeros_like(new_volume) + + +def crop_volume( + volume, + cropping_margin=None, + cropping_shape=None, + aff=None, + return_crop_idx=False, + mode="center", +): + """Crop volume by a given margin, or to a given shape. + :param volume: 2d or 3d numpy array (possibly with multiple channels) + :param cropping_margin: (optional) margin by which to crop the volume. The cropping margin is applied on both sides. + Can be an int, sequence or 1d numpy array of size n_dims. Should be given if cropping_shape is None. + :param cropping_shape: (optional) shape to which the volume will be cropped. Can be an int, sequence or 1d numpy + array of size n_dims. Should be given if cropping_margin is None. + :param aff: (optional) affine matrix of the input volume. + If not None, this function also returns an updated version of the affine matrix for the cropped volume. + :param return_crop_idx: (optional) whether to return the cropping indices used to crop the given volume. + :param mode: (optional) if cropping_shape is not None, whether to extract the centre of the image (mode='center'), + or to randomly crop the volume to the provided shape (mode='random'). Default is 'center'. + :return: cropped volume, corresponding affine matrix if aff is not None, and cropping indices if return_crop_idx is + True (in that order). + """ + + assert (cropping_margin is not None) | ( + cropping_shape is not None + ), "cropping_margin or cropping_shape should be provided" + assert not ( + (cropping_margin is not None) & (cropping_shape is not None) + ), "only one of cropping_margin or cropping_shape should be provided" + + # get info + new_volume = volume.copy() + vol_shape = new_volume.shape + n_dims, _ = utils.get_dims(vol_shape) + + # find cropping indices + if cropping_margin is not None: + cropping_margin = utils.reformat_to_list(cropping_margin, length=n_dims) + do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin) + min_crop_idx = [ + cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims) + ] + max_crop_idx = [ + vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] + for i in range(n_dims) + ] + else: + cropping_shape = utils.reformat_to_list(cropping_shape, length=n_dims) + if mode == "center": + min_crop_idx = np.maximum( + [int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0 + ) + max_crop_idx = np.minimum( + [min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)], + np.array(vol_shape)[:n_dims], + ) + elif mode == "random": + crop_max_val = np.maximum( + np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0 + ) + min_crop_idx = np.random.randint(0, high=crop_max_val + 1) + max_crop_idx = np.minimum( + min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims] + ) + else: + raise ValueError( + 'mode should be either "center" or "random", had %s' % mode + ) + crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)]) + + # crop volume + if n_dims == 2: + new_volume = new_volume[ + crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ... + ] + elif n_dims == 3: + new_volume = new_volume[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] + + # sort outputs + output = [new_volume] + if aff is not None: + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ np.array(min_crop_idx) + output.append(aff) + if return_crop_idx: + output.append(crop_idx) + return output[0] if len(output) == 1 else tuple(output) + + +def crop_volume_around_region( + volume, + mask=None, + masking_labels=None, + threshold=0.1, + margin=0, + cropping_shape=None, + cropping_shape_div_by=None, + aff=None, + overflow="strict", +): + """Crop a volume around a specific region. + This region is defined by a mask obtained by either: + 1) directly specifying it as input (see mask) + 2) keeping a set of label values if the volume is a label map (see masking_labels). + 3) thresholding the input volume (see threshold) + The cropping region is defined by the bounding box of the mask, which we can further modify by either: + 1) extending it by a margin (see margin) + 2) providing a specific cropping shape, in this case the cropping region will be centered around the bounding box + (see cropping_shape). + 3) extending it to a shape that is divisible by a given number. Again, the cropping region will be centered around + the bounding box (see cropping_shape_div_by). + Finally, if the size of the cropping region has been modified, and that this modified size overflows out of the + image (e.g. because the center of the mask is close to the edge), we can either: + 1) stick to the valid image space (the size of the modified cropping region won't be respected) + 2) shift the cropping region so that it lies on the valid image space, and if it still overflows, then we restrict + to the valid image space. + 3) pad the image with zeros, such that the cropping region is not ill-defined anymore. + 3) shift the cropping region to the valida image space, and if it still overflows, then we pad with zeros. + :param volume: a 2d or 3d numpy array + :param mask: (optional) mask of region to crop around. Must be same size as volume. Can either be boolean or 0/1. + If no mask is given, it will be computed by either thresholding the input volume or using masking_labels. + :param masking_labels: (optional) if mask is None, and if the volume is a label map, it can be cropped around a + set of labels specified in masking_labels, which can either be a single int, a sequence or a 1d numpy array. + :param threshold: (optional) if mask amd masking_labels are None, lower bound to determine values to crop around. + :param margin: (optional) add margin around mask + :param cropping_shape: (optional) shape to which the input volumes must be cropped. Volumes are padded around the + centre of the above-defined mask is they are too small for the given shape. Can be an integer or sequence. + Cannot be given at the same time as margin or cropping_shape_div_by. + :param cropping_shape_div_by: (optional) makes sure the shape of the cropped region is divisible by the provided + number. If it is not, then we enlarge the cropping area. If the enlarged area is too big for the input volume, we + pad it with 0. Must be an integer. Cannot be given at the same time as margin or cropping_shape. + :param aff: (optional) if specified, this function returns an updated affine matrix of the volume after cropping. + :param overflow: (optional) how to proceed when the cropping region overflows outside the initial image space. + Can either be 'strict' (default), 'shift-strict', 'padding', 'shift-padding. + :return: the cropped volume, the cropping indices (in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...]), + and the updated affine matrix if aff is not None. + """ + + assert not ( + (margin > 0) & (cropping_shape is not None) + ), "margin and cropping_shape can't be given together." + assert not ( + (margin > 0) & (cropping_shape_div_by is not None) + ), "margin and cropping_shape_div_by can't be given together." + assert not ( + (cropping_shape_div_by is not None) & (cropping_shape is not None) + ), "cropping_shape_div_by and cropping_shape can't be given together." + + new_vol = volume.copy() + n_dims, n_channels = utils.get_dims(new_vol.shape) + vol_shape = np.array(new_vol.shape[:n_dims]) + + # mask ROIs for cropping + if mask is None: + if masking_labels is not None: + _, mask = mask_label_map( + new_vol, masking_values=masking_labels, return_mask=True + ) + else: + mask = new_vol > threshold + + # find cropping indices + if np.any(mask): + + indices = np.nonzero(mask) + min_idx = np.array([np.min(idx) for idx in indices]) + max_idx = np.array([np.max(idx) for idx in indices]) + intermediate_vol_shape = max_idx - min_idx + + if (margin == 0) & (cropping_shape is None) & (cropping_shape_div_by is None): + cropping_shape = intermediate_vol_shape + if margin: + cropping_shape = intermediate_vol_shape + 2 * margin + elif cropping_shape is not None: + cropping_shape = np.array( + utils.reformat_to_list(cropping_shape, length=n_dims) + ) + elif cropping_shape_div_by is not None: + cropping_shape = [ + utils.find_closest_number_divisible_by_m( + s, cropping_shape_div_by, answer_type="higher" + ) + for s in intermediate_vol_shape + ] + + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) + min_overflow = np.abs(np.minimum(min_idx, 0)) + max_overflow = np.maximum(max_idx - vol_shape, 0) + + if "strict" in overflow: + min_overflow = np.zeros_like(min_overflow) + max_overflow = np.zeros_like(min_overflow) + + if overflow == "shift-strict": + min_idx -= max_overflow + max_idx += min_overflow + + if overflow == "shift-padding": + for ii in range(n_dims): + # no need to do anything if both min/max_overflow are 0 (no padding/shifting required at all) + # or if both are positive, because in this case we don't shift at all and we pad directly + if (min_overflow[ii] > 0) & (max_overflow[ii] == 0): + max_idx_new = max_idx[ii] + min_overflow[ii] + if max_idx_new <= vol_shape[ii]: + max_idx[ii] = max_idx_new + min_overflow[ii] = 0 + else: + min_overflow[ii] = min_overflow[ii] - ( + vol_shape[ii] - max_idx[ii] + ) + max_idx[ii] = vol_shape[ii] + elif (min_overflow[ii] == 0) & (max_overflow[ii] > 0): + min_idx_new = min_idx[ii] - max_overflow[ii] + if min_idx_new >= 0: + min_idx[ii] = min_idx_new + max_overflow[ii] = 0 + else: + max_overflow[ii] = max_overflow[ii] - min_idx[ii] + min_idx[ii] = 0 + + # crop volume if necessary + min_idx = np.maximum(min_idx, 0) + max_idx = np.minimum(max_idx, vol_shape) + cropping = np.concatenate([min_idx, max_idx]) + if np.any(cropping[:3] > 0) or np.any(cropping[3:] != vol_shape): + if n_dims == 3: + new_vol = new_vol[ + cropping[0] : cropping[3], + cropping[1] : cropping[4], + cropping[2] : cropping[5], + ..., + ] + elif n_dims == 2: + new_vol = new_vol[ + cropping[0] : cropping[2], cropping[1] : cropping[3], ... + ] + else: + raise ValueError("cannot crop volumes with more than 3 dimensions") + + # pad volume if necessary + if np.any(min_overflow > 0) | np.any(max_overflow > 0): + pad_margins = tuple( + [(min_overflow[i], max_overflow[i]) for i in range(n_dims)] + ) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + ) + new_vol = np.pad(new_vol, pad_margins, mode="constant", constant_values=0) + + # if there's nothing to crop around, we return the input as is + else: + min_idx = min_overflow = np.zeros(3) + cropping = None + + # return results + if aff is not None: + if n_dims == 2: + min_idx = np.append(min_idx, 0) + min_overflow = np.append(min_overflow, 0) + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ min_idx + aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_overflow + return new_vol, cropping, aff + else: + return new_vol, cropping + + +def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True): + """Crop a volume with given indices. + :param volume: a 2d or 3d numpy array + :param crop_idx: cropping indices, in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...]. + Can be a list or a 1d numpy array. + :param aff: (optional) if aff is specified, this function returns an updated affine matrix of the volume after + cropping. + :param n_dims: (optional) number of dimensions (excluding channels) of the volume. If not provided, n_dims will be + inferred from the input volume. + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: the cropped volume, and the updated affine matrix if aff is not None. + """ + + # get info + new_volume = volume.copy() if return_copy else volume + n_dims = int(np.array(crop_idx).shape[0] / 2) if n_dims is None else n_dims + + # crop image + if n_dims == 2: + new_volume = new_volume[ + crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ... + ] + elif n_dims == 3: + new_volume = new_volume[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] + else: + raise Exception("cannot crop volumes with more than 3 dimensions") + + if aff is not None: + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ crop_idx[:3] + return new_volume, aff + else: + return new_volume + + +def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=False): + """Pad volume to a given shape + :param volume: volume to be padded + :param padding_shape: shape to pad volume to. Can be a number, a sequence or a 1d numpy array. + :param padding_value: (optional) value used for padding + :param aff: (optional) affine matrix of the volume + :param return_pad_idx: (optional) the pad_idx corresponds to the indices where we should crop the resulting + padded image (ie the output of this function) to go back to the original volume (ie the input of this function). + :return: padded volume, and updated affine matrix if aff is not None. + """ + + # get info + new_volume = volume.copy() + vol_shape = new_volume.shape + n_dims, n_channels = utils.get_dims(vol_shape) + padding_shape = utils.reformat_to_list(padding_shape, length=n_dims, dtype="int") + + # check if need to pad + if np.any( + np.array(padding_shape, dtype="int32") + > np.array(vol_shape[:n_dims], dtype="int32") + ): + + # get padding margins + min_margins = np.maximum( + np.int32( + np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2) + ), + 0, + ) + max_margins = np.maximum( + np.int32( + np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2) + ), + 0, + ) + pad_idx = np.concatenate( + [min_margins, min_margins + np.array(vol_shape[:n_dims])] + ) + pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) + if n_channels > 1: + pad_margins = tuple(list(pad_margins) + [(0, 0)]) + + # pad volume + new_volume = np.pad( + new_volume, pad_margins, mode="constant", constant_values=padding_value + ) + + if aff is not None: + if n_dims == 2: + min_margins = np.append(min_margins, 0) + aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins + + else: + pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape[:n_dims])]) + + # sort outputs + output = [new_volume] + if aff is not None: + output.append(aff) + if return_pad_idx: + output.append(pad_idx) + return output[0] if len(output) == 1 else tuple(output) + + +def flip_volume(volume, axis=None, direction=None, aff=None, return_copy=True): + """Flip volume along a specified axis. + If unknown, this axis can be inferred from an affine matrix with a specified anatomical direction. + :param volume: a numpy array + :param axis: (optional) axis along which to flip the volume. Can either be an int or a tuple. + :param direction: (optional) if axis is None, the volume can be flipped along an anatomical direction: + 'rl' (right/left), 'ap' anterior/posterior), 'si' (superior/inferior). + :param aff: (optional) please provide an affine matrix if direction is not None + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: flipped volume + """ + + new_volume = volume.copy() if return_copy else volume + assert (axis is not None) | ( + (aff is not None) & (direction is not None) + ), "please provide either axis, or an affine matrix with a direction" + + # get flipping axis from aff if axis not provided + if (axis is None) & (aff is not None): + volume_axes = get_ras_axes(aff) + if direction == "rl": + axis = volume_axes[0] + elif direction == "ap": + axis = volume_axes[1] + elif direction == "si": + axis = volume_axes[2] + else: + raise ValueError( + "direction should be 'rl', 'ap', or 'si', had %s" % direction + ) + + # flip volume + return np.flip(new_volume, axis=axis) + + +def resample_volume(volume, aff, new_vox_size, interpolation="linear", blur=True): + """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS + :param volume: a numpy array + :param aff: affine matrix of the volume + :param new_vox_size: new voxel size (3 - element numpy vector) in mm + :param interpolation: (optional) type of interpolation. Can be 'linear' or 'nearest'. Default is 'linear'. + :param blur: (optional) whether to blur before resampling to avoid aliasing effects. + Only used if the input volume is downsampled. Default is True. + :return: new volume and affine matrix + """ + + pixdim = np.sqrt(np.sum(aff * aff, axis=0))[:-1] + new_vox_size = np.array(new_vox_size) + factor = pixdim / new_vox_size + sigmas = 0.25 / factor + sigmas[factor > 1] = 0 # don't blur if upsampling + + volume_filt = gaussian_filter(volume, sigmas) if blur else volume + + # volume2 = zoom(volume_filt, factor, order=1, mode='reflect', prefilter=False) + x = np.arange(0, volume_filt.shape[0]) + y = np.arange(0, volume_filt.shape[1]) + z = np.arange(0, volume_filt.shape[2]) + + my_interpolating_function = RegularGridInterpolator( + (x, y, z), volume_filt, method=interpolation + ) + + start = -(factor - 1) / (2 * factor) + step = 1.0 / factor + stop = start + step * np.ceil(volume_filt.shape * factor) + + xi = np.arange(start=start[0], stop=stop[0], step=step[0]) + yi = np.arange(start=start[1], stop=stop[1], step=step[1]) + zi = np.arange(start=start[2], stop=stop[2], step=step[2]) + xi[xi < 0] = 0 + yi[yi < 0] = 0 + zi[zi < 0] = 0 + xi[xi > (volume_filt.shape[0] - 1)] = volume_filt.shape[0] - 1 + yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1 + zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1 + + xig, yig, zig = np.meshgrid(xi, yi, zi, indexing="ij", sparse=True) + volume2 = my_interpolating_function((xig, yig, zig)) + + aff2 = aff.copy() + for c in range(3): + aff2[:-1, c] = aff2[:-1, c] / factor[c] + aff2[:-1, -1] = aff2[:-1, -1] - np.matmul(aff2[:-1, :-1], 0.5 * (factor - 1)) + + return volume2, aff2 + + +def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation="linear"): + """This function reslices a floating image to the space of a reference image + :param vol_ref: a numpy array with the reference volume + :param aff_ref: affine matrix of the reference volume + :param vol_flo: a numpy array with the floating volume + :param aff_flo: affine matrix of the floating volume + :param interpolation: (optional) type of interpolation. Can be 'linear' or 'nearest'. Default is 'linear'. + :return: resliced volume + """ + + T = np.matmul(np.linalg.inv(aff_flo), aff_ref) + + xf = np.arange(0, vol_flo.shape[0]) + yf = np.arange(0, vol_flo.shape[1]) + zf = np.arange(0, vol_flo.shape[2]) + + my_interpolating_function = RegularGridInterpolator( + (xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, method=interpolation + ) + + xr = np.arange(0, vol_ref.shape[0]) + yr = np.arange(0, vol_ref.shape[1]) + zr = np.arange(0, vol_ref.shape[2]) + + xrg, yrg, zrg = np.meshgrid(xr, yr, zr, indexing="ij", sparse=False) + n = xrg.size + xrg = xrg.reshape([n]) + yrg = yrg.reshape([n]) + zrg = zrg.reshape([n]) + bottom = np.ones_like(xrg) + coords = np.stack([xrg, yrg, zrg, bottom]) + coords_new = np.matmul(T, coords)[:-1, :] + result = my_interpolating_function( + (coords_new[0, :], coords_new[1, :], coords_new[2, :]) + ) + + return result.reshape(vol_ref.shape) + + +def get_ras_axes(aff, n_dims=3): + """This function finds the RAS axes corresponding to each dimension of a volume, based on its affine matrix. + :param aff: affine matrix Can be a 2d numpy array of size n_dims*n_dims, n_dims+1*n_dims+1, or n_dims*n_dims+1. + :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix. + :return: two numpy 1d arrays of length n_dims, one with the axes corresponding to RAS orientations, + and one with their corresponding direction. + """ + aff_inverted = np.linalg.inv(aff) + img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0) + for i in range(n_dims): + if i not in img_ras_axes: + unique, counts = np.unique(img_ras_axes, return_counts=True) + incorrect_value = unique[np.argmax(counts)] + img_ras_axes[np.where(img_ras_axes == incorrect_value)[0][-1]] = i + + return img_ras_axes + + +def align_volume_to_ref( + volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True +): + """This function aligns a volume to a reference orientation (axis and direction) specified by an affine matrix. + :param volume: a numpy array + :param aff: affine matrix of the floating volume + :param aff_ref: (optional) affine matrix of the target orientation. Default is identity matrix. + :param return_aff: (optional) whether to return the affine matrix of the aligned volume + :param n_dims: (optional) number of dimensions (excluding channels) of the volume. If not provided, n_dims will be + inferred from the input volume. + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: aligned volume, with corresponding affine matrix if return_aff is True. + """ + + # work on copy + new_volume = volume.copy() if return_copy else volume + aff_flo = aff.copy() + + # default value for aff_ref + if aff_ref is None: + aff_ref = np.eye(4) + + # extract ras axes + if n_dims is None: + n_dims, _ = utils.get_dims(new_volume.shape) + ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims) + ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims) + + # align axes + aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo] + for i in range(n_dims): + if ras_axes_flo[i] != ras_axes_ref[i]: + new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i]) + swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i]) + ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ( + ras_axes_flo[i], + ras_axes_flo[swapped_axis_idx], + ) + + # align directions + dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0) + for i in range(n_dims): + if dot_products[i] < 0: + new_volume = np.flip(new_volume, axis=i) + aff_flo[:, i] = -aff_flo[:, i] + aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1) + + if return_aff: + return new_volume, aff_flo + else: + return new_volume + + +def blur_volume(volume, sigma, mask=None): + """Blur volume with gaussian masks of given sigma. + :param volume: 2d or 3d numpy array + :param sigma: standard deviation of the gaussian kernels. Can be a number, a sequence or a 1d numpy array + :param mask: (optional) numpy array of the same shape as volume to correct for edge blurring effects. + Mask can be a boolean or numerical array. In the latter, the mask is computed by keeping all values above zero. + :return: blurred volume + """ + + # initialisation + new_volume = volume.copy() + n_dims, _ = utils.get_dims(new_volume.shape) + sigma = utils.reformat_to_list(sigma, length=n_dims, dtype="float") + + # blur image + new_volume = gaussian_filter( + new_volume, sigma=sigma, mode="nearest" + ) # nearest refers to edge padding + + # correct edge effect if mask is not None + if mask is not None: + assert new_volume.shape == mask.shape, ( + "volume and mask should have the same dimensions: " + "got {0} and {1}".format(new_volume.shape, mask.shape) + ) + mask = (mask > 0) * 1.0 + blurred_mask = gaussian_filter(mask, sigma=sigma, mode="nearest") + new_volume = new_volume / (blurred_mask + 1e-6) + new_volume[mask == 0] = 0 + + return new_volume + + +# --------------------------------------------------- edit label map --------------------------------------------------- + + +def correct_label_map( + labels, + list_incorrect_labels, + list_correct_labels=None, + use_nearest_label=False, + remove_zero=False, + smooth=False, +): + """This function corrects specified label values in a label map by either a list of given values, or by the nearest + label. + :param labels: a 2d or 3d label map + :param list_incorrect_labels: list of all label values to correct (eg [1, 2, 3]). Can also be a path to such a list. + :param list_correct_labels: (optional) list of correct label values to replace the incorrect ones. + Correct values must have the same order as their corresponding value in list_incorrect_labels. + When several correct values are possible for the same incorrect value, the nearest correct value will be selected at + each voxel to correct. In that case, the different correct values must be specified inside a list within + list_correct_labels (e.g. [10, 20, 30, [40, 50]). + :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels. + :param remove_zero: (optional) if use_nearest_label is True, set to True not to consider zero among the potential + candidates for the nearest neighbour. -1 will be returned when no solution are possible. + :param smooth: (optional) whether to smooth the corrected label map + :return: corrected label map + """ + + assert ( + list_correct_labels is not None + ) | use_nearest_label, ( + "please provide a list of correct labels, or set use_nearest_label to True." + ) + assert (list_correct_labels is None) | ( + not use_nearest_label + ), "cannot provide a list of correct values and set use_nearest_label to True" + + # initialisation + new_labels = labels.copy() + list_incorrect_labels = utils.reformat_to_list( + utils.load_array_if_path(list_incorrect_labels) + ) + volume_labels = np.unique(labels) + n_dims, _ = utils.get_dims(labels.shape) + + # use list of correct values + if list_correct_labels is not None: + list_correct_labels = utils.reformat_to_list( + utils.load_array_if_path(list_correct_labels) + ) + + # loop over label values + for incorrect_label, correct_label in zip( + list_incorrect_labels, list_correct_labels + ): + if incorrect_label in volume_labels: + + # only one possible value to replace with + if isinstance( + correct_label, (int, float, np.int64, np.int32, np.int16, np.int8) + ): + incorrect_voxels = np.where(labels == incorrect_label) + new_labels[incorrect_voxels] = correct_label + + # several possibilities + elif isinstance(correct_label, (tuple, list)): + + # make sure at least one correct label is present + if not any([lab in volume_labels for lab in correct_label]): + print( + "no correct values found in volume, please adjust: " + "incorrect: {}, correct: {}".format( + incorrect_label, correct_label + ) + ) + + # crop around incorrect label until we find incorrect labels + correct_label_not_found = True + margin_mult = 1 + tmp_labels = None + crop = None + while correct_label_not_found: + tmp_labels, crop = crop_volume_around_region( + labels, + masking_labels=incorrect_label, + margin=10 * margin_mult, + ) + correct_label_not_found = not any( + [lab in np.unique(tmp_labels) for lab in correct_label] + ) + margin_mult += 1 + + # calculate distance maps for all new label candidates + incorrect_voxels = np.where(tmp_labels == incorrect_label) + distance_map_list = [ + distance_transform_edt(tmp_labels != lab) + for lab in correct_label + ] + distances_correct = np.stack( + [dist[incorrect_voxels] for dist in distance_map_list] + ) + + # select nearest values and use them to correct label map + idx_correct_lab = np.argmin(distances_correct, axis=0) + incorrect_voxels = tuple( + [incorrect_voxels[i] + crop[i] for i in range(n_dims)] + ) + new_labels[incorrect_voxels] = np.array(correct_label)[ + idx_correct_lab + ] + + # use nearest label + else: + + # loop over label values + for incorrect_label in list_incorrect_labels: + if incorrect_label in volume_labels: + + # loop around regions + components, n_components = scipy_label(labels == incorrect_label) + loop_info = utils.LoopInfo(n_components + 1, 100, "correcting") + for i in range(1, n_components + 1): + loop_info.update(i) + + # crop each region + _, crop = crop_volume_around_region( + components, masking_labels=i, margin=1 + ) + tmp_labels = crop_volume_with_idx(labels, crop) + tmp_new_labels = crop_volume_with_idx(new_labels, crop) + + # list all possible correct labels + correct_labels = np.unique(tmp_labels) + for il in list_incorrect_labels: + correct_labels = np.delete( + correct_labels, np.where(correct_labels == il) + ) + if remove_zero: + correct_labels = np.delete( + correct_labels, np.where(correct_labels == 0) + ) + + # replace incorrect voxels by new value + incorrect_voxels = np.where(tmp_labels == incorrect_label) + if len(correct_labels) == 0: + tmp_new_labels[incorrect_voxels] = -1 + else: + if len(correct_labels) == 1: + idx_correct_lab = np.zeros( + len(incorrect_voxels[0]), dtype="int32" + ) + else: + distance_map_list = [ + distance_transform_edt(tmp_labels != lab) + for lab in correct_labels + ] + distances_correct = np.stack( + [dist[incorrect_voxels] for dist in distance_map_list] + ) + idx_correct_lab = np.argmin(distances_correct, axis=0) + tmp_new_labels[incorrect_voxels] = np.array(correct_labels)[ + idx_correct_lab + ] + + # paste back + if n_dims == 2: + new_labels[crop[0] : crop[2], crop[1] : crop[3], ...] = ( + tmp_new_labels + ) + else: + new_labels[ + crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ... + ] = tmp_new_labels + + # smoothing + if smooth: + kernel = np.ones(tuple([3] * n_dims)) + new_labels = smooth_label_map(new_labels, kernel) + + return new_labels + + +def mask_label_map(labels, masking_values, masking_value=0, return_mask=False): + """ + This function masks a label map around a list of specified values. + :param labels: input label map + :param masking_values: list of values to mask around + :param masking_value: (optional) value to mask the label map with + :param return_mask: (optional) whether to return the applied mask + :return: the masked label map, and the applied mask if return_mask is True. + """ + + # build mask and mask labels + mask = np.zeros(labels.shape, dtype=bool) + masked_labels = labels.copy() + for value in utils.reformat_to_list(masking_values): + mask = mask | (labels == value) + masked_labels[np.logical_not(mask)] = masking_value + + if return_mask: + mask = mask * 1 + return masked_labels, mask + else: + return masked_labels + + +def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): + """This function smooth an input label map by replacing each voxel by the value of its most numerous neighbour. + :param labels: input label map + :param kernel: kernel when counting neighbours. Must contain only zeros or ones. + :param labels_list: list of label values to smooth. Defaults is None, where all labels are smoothed. + :param print_progress: (optional) If not 0, interval at which to print the number of processed labels. + :return: smoothed label map + """ + # get info + labels_shape = labels.shape + unique_labels = np.unique(labels).astype("int32") + if labels_list is None: + labels_list = unique_labels + new_labels = mask_new_labels = None + else: + labels_to_keep = [lab for lab in unique_labels if lab not in labels_list] + new_labels, mask_new_labels = mask_label_map( + labels, labels_to_keep, return_mask=True + ) + + # loop through label values + count = np.zeros(labels_shape) + labels_smoothed = np.zeros(labels_shape, dtype="int") + loop_info = utils.LoopInfo(len(labels_list), print_progress, "smoothing") + for la, label in enumerate(labels_list): + if print_progress: + loop_info.update(la) + + # count neighbours with same value + mask = (labels == label) * 1 + n_neighbours = convolve(mask, kernel) + + # update label map and maximum neighbour counts + idx = n_neighbours > count + count[idx] = n_neighbours[idx] + labels_smoothed[idx] = label + labels_smoothed = labels_smoothed.astype("int32") + + if new_labels is None: + new_labels = labels_smoothed + else: + new_labels = np.where(mask_new_labels, new_labels, labels_smoothed) + + return new_labels + + +def erode_label_map( + labels, + labels_to_erode, + erosion_factors=1.0, + gpu=False, + model=None, + return_model=False, +): + """Erode a given set of label values within a label map. + :param labels: a 2d or 3d label map + :param labels_to_erode: list of label values to erode + :param erosion_factors: (optional) list of erosion factors to use for each label. If values are integers, normal + erosion applies. If float, we first 1) blur a mask of the corresponding label value, and 2) use the erosion factor + as a threshold in the blurred mask. + If erosion_factors is a single value, the same factor will be applied to all labels. + :param gpu: (optional) whether to use a fast gpu model for blurring (if erosion factors are floats) + :param model: (optional) gpu model for blurring masks (if erosion factors are floats) + :param return_model: (optional) whether to return the gpu blurring model + :return: eroded label map, and gpu blurring model is return_model is True. + """ + # reformat labels_to_erode and erode + new_labels = labels.copy() + labels_to_erode = utils.reformat_to_list(labels_to_erode) + erosion_factors = utils.reformat_to_list( + erosion_factors, length=len(labels_to_erode) + ) + labels_shape = list(new_labels.shape) + n_dims, _ = utils.get_dims(labels_shape) + + # loop over labels to erode + for label_to_erode, erosion_factor in zip(labels_to_erode, erosion_factors): + + assert ( + erosion_factor > 0 + ), "all erosion factors should be strictly positive, had {}".format( + erosion_factor + ) + + # get mask of current label value + mask = new_labels == label_to_erode + + # erode as usual if erosion factor is int + if int(erosion_factor) == erosion_factor: + erode_struct = utils.build_binary_structure(int(erosion_factor), n_dims) + eroded_mask = binary_erosion(mask, erode_struct) + + # blur mask and use erosion factor as a threshold if float + else: + if gpu: + if model is None: + mask_in = KL.Input(shape=labels_shape + [1], dtype="float32") + blurred_mask = GaussianBlur([1] * 3)(mask_in) + model = Model(inputs=mask_in, outputs=blurred_mask) + eroded_mask = model.predict( + utils.add_axis(np.float32(mask), axis=[0, -1]) + ) + else: + eroded_mask = blur_volume(np.array(mask, dtype="float32"), 1) + eroded_mask = np.squeeze(eroded_mask) > erosion_factor + + # crop label map and mask around values to change + mask = mask & np.logical_not(eroded_mask) + cropped_lab_mask, cropping = crop_volume_around_region(mask, margin=3) + cropped_labels = crop_volume_with_idx(new_labels, cropping) + + # calculate distance maps for all labels in cropped_labels + labels_list = np.unique(cropped_labels) + labels_list = labels_list[labels_list != label_to_erode] + list_dist_maps = [ + distance_transform_edt(np.logical_not(cropped_labels == la)) + for la in labels_list + ] + candidate_distances = np.stack( + [dist[cropped_lab_mask] for dist in list_dist_maps] + ) + + # select nearest value and put cropped labels back to full label map + idx_correct_lab = np.argmin(candidate_distances, axis=0) + cropped_labels[cropped_lab_mask] = np.array(labels_list)[idx_correct_lab] + if n_dims == 2: + new_labels[cropping[0] : cropping[2], cropping[1] : cropping[3], ...] = ( + cropped_labels + ) + elif n_dims == 3: + new_labels[ + cropping[0] : cropping[3], + cropping[1] : cropping[4], + cropping[2] : cropping[5], + ..., + ] = cropped_labels + + if return_model: + return new_labels, model + else: + return new_labels + + +def get_largest_connected_component(mask, structure=None): + """Function to get the largest connected component for a given input. + :param mask: a 2d or 3d label map of boolean type. + :param structure: numpy array defining the connectivity. + """ + components, n_components = scipy_label(mask, structure) + return ( + components == np.argmax(np.bincount(components.flat)[1:]) + 1 + if n_components > 0 + else mask.copy() + ) + + +def compute_hard_volumes( + labels, voxel_volume=1.0, label_list=None, skip_background=True +): + """Compute hard volumes in a label map. + :param labels: a label map + :param voxel_volume: (optional) volume of voxel. Default is 1 (i.e. returned volumes are voxel counts). + :param label_list: (optional) list of labels to compute volumes for. Can be an int, a sequence, or a numpy array. + If None, the volumes of all label values are computed. + :param skip_background: (optional) whether to skip computing the volume of the background. + If label_list is None, this assumes background value is 0. + If label_list is not None, this assumes the background is the first value in label list. + :return: numpy 1d vector with the volumes of each structure + """ + + # initialisation + subject_label_list = utils.reformat_to_list(np.unique(labels), dtype="int") + if label_list is None: + label_list = subject_label_list + else: + label_list = utils.reformat_to_list(label_list) + if skip_background: + label_list = label_list[1:] + volumes = np.zeros(len(label_list)) + + # loop over label values + for idx, label in enumerate(label_list): + if label in subject_label_list: + mask = (labels == label) * 1 + volumes[idx] = np.sum(mask) + else: + volumes[idx] = 0 + + return volumes * voxel_volume + + +def compute_distance_map(labels, masking_labels=None, crop_margin=None): + """Compute distance map for a given list of label values in a label map. + :param labels: a label map + :param masking_labels: (optional) list of label values to mask the label map with. The distances will be computed + for these labels only. Default is None, where all positive values are considered. + :param crop_margin: (optional) margin with which to crop the input label maps around the labels for which we + want to compute the distance maps. + :return: a distance map with positive values inside the considered regions, and negative values outside. + """ + + n_dims, _ = utils.get_dims(labels.shape) + + # crop label map if necessary + if crop_margin is not None: + tmp_labels, crop_idx = crop_volume_around_region(labels, margin=crop_margin) + else: + tmp_labels = labels + crop_idx = None + + # mask label map around specify values + if masking_labels is not None: + masking_labels = utils.reformat_to_list(masking_labels) + mask = np.zeros(tmp_labels.shape, dtype="bool") + for masking_label in masking_labels: + mask = mask | tmp_labels == masking_label + else: + mask = tmp_labels > 0 + not_mask = np.logical_not(mask) + + # compute distances + dist_in = distance_transform_edt(mask) + dist_in = np.where(mask, dist_in - 0.5, dist_in) + dist_out = -distance_transform_edt(not_mask) + dist_out = np.where(not_mask, dist_out + 0.5, dist_out) + tmp_dist = dist_in + dist_out + + # put back in original matrix if we cropped + if crop_idx is not None: + dist = np.min(tmp_dist) * np.ones(labels.shape, dtype="float32") + if n_dims == 3: + dist[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] = tmp_dist + elif n_dims == 2: + dist[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ...] = tmp_dist + else: + dist = tmp_dist + + return dist + + +# ------------------------------------------------- edit volumes in dir ------------------------------------------------ + + +def mask_images_in_dir( + image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + dilate=0, + erode=0, + fill_holes=False, + masking_value=0, + write_mask=False, + mask_result_dir=None, + recompute=True, +): + """Mask all volumes in a folder, either with masks in a specified folder, or by keeping only the intensity values + above a specified threshold. + :param image_dir: path of directory with images to mask + :param result_dir: path of directory where masked images will be writen + :param mask_dir: (optional) path of directory containing masks. Masks are matched to images by sorting order. + Mask volumes don't have to be boolean or 0/1 arrays as all strictly positive values are used to build the masks. + Masks should have the same size as images. If images are multi-channel, masks can either be uni- or multi-channel. + In the first case, the same mask is applied to all channels. + :param threshold: (optional) If mask is None, masking is performed by keeping thresholding the input. + :param dilate: (optional) number of voxels by which to dilate the provided or computed masks. + :param erode: (optional) number of voxels by which to erode the provided or computed masks. + :param fill_holes: (optional) whether to fill the holes in the provided or computed masks. + :param masking_value: (optional) masking value + :param write_mask: (optional) whether to write the applied masks + :param mask_result_dir: (optional) path of resulting masks, if write_mask is True + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + if mask_result_dir is not None: + utils.mkdir(mask_result_dir) + + # get path masks if necessary + path_images = utils.list_images_in_folder(image_dir) + if mask_dir is not None: + path_masks = utils.list_images_in_folder(mask_dir) + else: + path_masks = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, "masking", True) + for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): + loop_info.update(idx) + + # mask images + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + if path_mask is not None: + mask = utils.load_volume(path_mask) + else: + mask = None + im = mask_volume( + im, + mask, + threshold, + dilate, + erode, + fill_holes, + masking_value, + write_mask, + ) + + # write mask if necessary + if write_mask: + assert ( + mask_result_dir is not None + ), "if write_mask is True, mask_result_dir has to be specified as well" + mask_result_path = os.path.join( + mask_result_dir, os.path.basename(path_image) + ) + utils.save_volume(im[1], aff, h, mask_result_path) + utils.save_volume(im[0], aff, h, path_result) + else: + utils.save_volume(im, aff, h, path_result) + + +def rescale_images_in_dir( + image_dir, + result_dir, + new_min=0, + new_max=255, + min_percentile=2, + max_percentile=98, + use_positive_only=True, + recompute=True, +): + """This function linearly rescales all volumes in image_dir between new_min and new_max. + :param image_dir: path of directory with images to rescale + :param result_dir: path of directory where rescaled images will be writen + :param new_min: (optional) minimum value for the rescaled images. + :param new_max: (optional) maximum value for the rescaled images. + :param min_percentile: (optional) percentile for estimating robust minimum of volume (float in [0,...100]), + where 0 = np.min + :param max_percentile: (optional) percentile for estimating robust maximum of volume (float in [0,...100]), + where 100 = np.max + :param use_positive_only: (optional) whether to use only positive values when estimating the min and max percentile + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, "rescaling", True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im = rescale_volume( + im, new_min, new_max, min_percentile, max_percentile, use_positive_only + ) + utils.save_volume(im, aff, h, path_result) + + +def crop_images_in_dir( + image_dir, result_dir, cropping_margin=None, cropping_shape=None, recompute=True +): + """Crop all volumes in a folder by a given margin, or to a given shape. + :param image_dir: path of directory with images to rescale + :param result_dir: path of directory where cropped images will be writen + :param cropping_margin: (optional) margin by which to crop the volume. + Can be an int, a sequence or a 1d numpy array. Should be given if cropping_shape is None. + :param cropping_shape: (optional) shape to which the volume will be cropped. + Can be an int, a sequence or a 1d numpy array. Should be given if cropping_margin is None. + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # loop over images and masks + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # crop image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + volume, aff, h = utils.load_volume(path_image, im_only=False) + volume, aff = crop_volume(volume, cropping_margin, cropping_shape, aff) + utils.save_volume(volume, aff, h, path_result) + + +def crop_images_around_region_in_dir( + image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + masking_labels=None, + crop_margin=5, + recompute=True, +): + """Crop all volumes in a folder around a region, which is defined for each volume by a mask obtained by either + 1) directly providing it as input + 2) thresholding the input volume + 3) keeping a set of label values if the volume is a label map. + :param image_dir: path of directory with images to crop + :param result_dir: path of directory where cropped images will be writen + :param mask_dir: (optional) path of directory of input masks + :param threshold: (optional) lower bound to determine values to crop around + :param masking_labels: (optional) if the volume is a label map, it can be cropped around a given set of labels by + specifying them in masking_labels, which can either be a single int, a list or a 1d numpy array. + :param crop_margin: (optional) cropping margin + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list volumes and masks + path_images = utils.list_images_in_folder(image_dir) + if mask_dir is not None: + path_masks = utils.list_images_in_folder(mask_dir) + else: + path_masks = [None] * len(path_images) + + # loop over images and masks + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) + for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): + loop_info.update(idx) + + # crop image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + volume, aff, h = utils.load_volume(path_image, im_only=True) + if path_mask is not None: + mask = utils.load_volume(path_mask) + else: + mask = None + volume, cropping, aff = crop_volume_around_region( + volume, mask, threshold, masking_labels, crop_margin, aff + ) + utils.save_volume(volume, aff, h, path_result) + + +def pad_images_in_dir( + image_dir, result_dir, max_shape=None, padding_value=0, recompute=True +): + """Pads all the volumes in a folder to the same shape (either provided or computed). + :param image_dir: path of directory with images to pad + :param result_dir: path of directory where padded images will be writen + :param max_shape: (optional) shape to pad the volumes to. Can be an int, a sequence or a 1d numpy array. + If None, volumes will be padded to the shape of the biggest volume in image_dir. + :param padding_value: (optional) value to pad the volumes with. + :param recompute: (optional) whether to recompute result files even if they already exist + :return: shape of the padded volumes. + """ + + # create result dir + utils.mkdir(result_dir) + + # list labels + path_images = utils.list_images_in_folder(image_dir) + + # get maximum shape + if max_shape is None: + max_shape, aff, _, _, h, _ = utils.get_volume_info(path_images[0]) + for path_image in path_images[1:]: + image_shape, aff, _, _, h, _ = utils.get_volume_info(path_image) + max_shape = tuple( + np.maximum(np.asarray(max_shape), np.asarray(image_shape)) + ) + max_shape = np.array(max_shape) + + # loop over label maps + loop_info = utils.LoopInfo(len(path_images), 10, "padding", True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # pad map + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im, aff = pad_volume(im, max_shape, padding_value, aff) + utils.save_volume(im, aff, h, path_result) + + return max_shape + + +def flip_images_in_dir( + image_dir, result_dir, axis=None, direction=None, recompute=True +): + """Flip all images in a directory along a specified axis. + If unknown, this axis can be replaced by an anatomical direction. + :param image_dir: path of directory with images to flip + :param result_dir: path of directory where flipped images will be writen + :param axis: (optional) axis along which to flip the volume + :param direction: (optional) if axis is None, the volume can be flipped along an anatomical direction: + 'rl' (right/left), 'ap' (anterior/posterior), 'si' (superior/inferior). + :param recompute: (optional) whether to recompute result files even if they already exists + """ + # create result dir + utils.mkdir(result_dir) + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, "flipping", True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # flip image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im = flip_volume(im, axis=axis, direction=direction, aff=aff) + utils.save_volume(im, aff, h, path_result) + + +def align_images_in_dir( + image_dir, result_dir, aff_ref=None, path_ref=None, recompute=True +): + """This function aligns all images in image_dir to a reference orientation (axes and directions). + This reference orientation can be directly provided as an affine matrix, or can be specified by a reference volume. + If neither are provided, the reference orientation is assumed to be an identity matrix. + :param image_dir: path of directory with images to align + :param result_dir: path of directory where flipped images will be writen + :param aff_ref: (optional) reference affine matrix. Can be a numpy array, or the path to such array. + :param path_ref: (optional) path of a volume to which all images will be aligned. Can also be the path to a folder + with as many images as in image_dir, in which case each image in image_dir is aligned to its counterpart in path_ref + (they are matched by sorting order). + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + path_images = utils.list_images_in_folder(image_dir) + + # read reference affine matrix + if path_ref is not None: + assert aff_ref is None, "cannot provide aff_ref and path_ref together." + basename = os.path.basename(path_ref) + if ( + (".nii.gz" in basename) + | (".nii" in basename) + | (".mgz" in basename) + | (".npz" in basename) + ): + _, aff_ref, _ = utils.load_volume(path_ref, im_only=False) + path_refs = [None] * len(path_images) + else: + path_refs = utils.list_images_in_folder(path_ref) + elif aff_ref is not None: + aff_ref = utils.load_array_if_path(aff_ref) + path_refs = [None] * len(path_images) + else: + aff_ref = np.eye(4) + path_refs = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, "aligning", True) + for idx, (path_image, path_ref) in enumerate(zip(path_images, path_refs)): + loop_info.update(idx) + + # align image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + if path_ref is not None: + _, aff_ref, _ = utils.load_volume(path_ref, im_only=False) + im, aff = align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=True) + utils.save_volume(im, aff, h, path_result) + + +def correct_nans_images_in_dir(image_dir, result_dir, recompute=True): + """Correct NaNs in all images in a directory. + :param image_dir: path of directory with images to correct + :param result_dir: path of directory where corrected images will be writen + :param recompute: (optional) whether to recompute result files even if they already exists + """ + # create result dir + utils.mkdir(result_dir) + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, "correcting", True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # flip image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im[np.isnan(im)] = 0 + utils.save_volume(im, aff, h, path_result) + + +def blur_images_in_dir( + image_dir, result_dir, sigma, mask_dir=None, gpu=False, recompute=True +): + """This function blurs all the images in image_dir with kernels of the specified std deviations. + :param image_dir: path of directory with images to blur + :param result_dir: path of directory where blurred images will be writen + :param sigma: standard deviation of the blurring gaussian kernels. + Can be a number (isotropic blurring), or a sequence with the same length as the number of dimensions of images. + :param mask_dir: (optional) path of directory with masks of the region to blur. + Images and masks are matched by sorting order. + :param gpu: (optional) whether to use a fast gpu model for blurring + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list images and masks + path_images = utils.list_images_in_folder(image_dir) + if mask_dir is not None: + path_masks = utils.list_images_in_folder(mask_dir) + else: + path_masks = [None] * len(path_images) + + # loop over images + previous_model_input_shape = None + model = None + loop_info = utils.LoopInfo(len(path_images), 10, "blurring", True) + for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): + loop_info.update(idx) + + # load image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, im_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_image, return_volume=True + ) + if path_mask is not None: + mask = utils.load_volume(path_mask) + assert ( + mask.shape == im.shape + ), "mask and image should have the same shape" + else: + mask = None + + # blur image + if gpu: + if (im_shape != previous_model_input_shape) | (model is None): + previous_model_input_shape = im_shape + inputs = [KL.Input(shape=im_shape + [1])] + sigma = utils.reformat_to_list(sigma, length=n_dims) + if mask is None: + image = GaussianBlur(sigma=sigma)(inputs[0]) + else: + inputs.append(KL.Input(shape=im_shape + [1], dtype="float32")) + image = GaussianBlur(sigma=sigma, use_mask=True)(inputs) + model = Model(inputs=inputs, outputs=image) + if mask is None: + im = np.squeeze(model.predict(utils.add_axis(im, axis=[0, -1]))) + else: + im = np.squeeze( + model.predict( + [utils.add_axis(im, [0, -1]), utils.add_axis(mask, [0, -1])] + ) + ) + else: + im = blur_volume(im, sigma, mask=mask) + utils.save_volume(im, aff, h, path_result) + + +def create_mutlimodal_images(list_channel_dir, result_dir, recompute=True): + """This function forms multimodal images by stacking channels located in different folders. + :param list_channel_dir: list of all directories, each containing the same channel for all images. + Channels are matched between folders by sorting order. + :param result_dir: path of directory where multimodal images will be writen + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + assert isinstance( + list_channel_dir, (list, tuple) + ), "list_channel_dir should be a list or a tuple" + + # gather path of all images for all channels + list_channel_paths = [utils.list_images_in_folder(d) for d in list_channel_dir] + n_images = len(list_channel_paths[0]) + n_channels = len(list_channel_dir) + for channel_paths in list_channel_paths: + if len(channel_paths) != n_images: + raise ValueError("all directories should have the same number of files") + + # loop over images + loop_info = utils.LoopInfo(n_images, 10, "processing", True) + for idx in range(n_images): + loop_info.update(idx) + + # stack all channels and save multichannel image + path_result = os.path.join( + result_dir, os.path.basename(list_channel_paths[0][idx]) + ) + if (not os.path.isfile(path_result)) | recompute: + list_channels = list() + tmp_aff = None + tmp_h = None + for channel_idx in range(n_channels): + tmp_channel, tmp_aff, tmp_h = utils.load_volume( + list_channel_paths[channel_idx][idx], im_only=False + ) + list_channels.append(tmp_channel) + im = np.stack(list_channels, axis=-1) + utils.save_volume(im, tmp_aff, tmp_h, path_result) + + +def convert_images_in_dir_to_nifty( + image_dir, result_dir, aff=None, ref_aff_dir=None, recompute=True +): + """Converts all images in image_dir to nifty format. + :param image_dir: path of directory with images to convert + :param result_dir: path of directory where converted images will be writen + :param aff: (optional) affine matrix in homogeneous coordinates with which to write the images. + Can also be 'FS' to write images with FreeSurfer typical affine matrix. + :param ref_aff_dir: (optional) alternatively to providing a fixed aff, different affine matrices can be used for + each image in image_dir by matching them to corresponding volumes contained in ref_aff_dir. + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list images + path_images = utils.list_images_in_folder(image_dir) + if ref_aff_dir is not None: + path_ref_images = utils.list_images_in_folder(ref_aff_dir) + else: + path_ref_images = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, "converting", True) + for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)): + loop_info.update(idx) + + # convert images to nifty format + path_result = ( + os.path.join( + result_dir, os.path.basename(utils.strip_extension(path_image)) + ) + + ".nii.gz" + ) + if (not os.path.isfile(path_result)) | recompute: + if utils.get_image_extension(path_image) == "nii.gz": + shutil.copy2(path_image, path_result) + else: + im, tmp_aff, h = utils.load_volume(path_image, im_only=False) + if aff is not None: + tmp_aff = aff + elif path_ref is not None: + _, tmp_aff, h = utils.load_volume(path_ref, im_only=False) + utils.save_volume(im, tmp_aff, h, path_result) + + +def mri_convert_images_in_dir( + image_dir, + result_dir, + interpolation=None, + reference_dir=None, + same_reference=False, + voxsize=None, + path_freesurfer="/usr/local/freesurfer", + mri_convert_path="/usr/local/freesurfer/bin/mri_convert", + recompute=True, +): + """This function launches mri_convert on all images contained in image_dir, and writes the results in result_dir. + The interpolation type can be specified (i.e. 'nearest'), as well as a folder containing references for resampling. + reference_dir can be the path of a single *image* if same_reference=True. + :param image_dir: path of directory with images to convert + :param result_dir: path of directory where converted images will be writen + :param interpolation: (optional) interpolation type, can be 'inter' (default), 'cubic', 'nearest', 'trilinear' + :param reference_dir: (optional) path of directory with reference images. References are matched to images by + sorting order. If same_reference is false, references and images are matched by sorting order. + This can also be the path to a single image that will be used as reference for all images im image_dir (set + same_reference to true in that case). + :param same_reference: (optional) whether to use a single image as reference for all images to interpolate. + :param voxsize: (optional) resolution at which to resample converted image. Must be a list of length n_dims. + :param path_freesurfer: (optional) path FreeSurfer home + :param mri_convert_path: (optional) path mri_convert binary file + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # set up FreeSurfer + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = mri_convert_path + " " + + # list images + path_images = utils.list_images_in_folder(image_dir) + if reference_dir is not None: + if same_reference: + path_references = [reference_dir] * len(path_images) + else: + path_references = utils.list_images_in_folder(reference_dir) + assert len(path_references) == len( + path_images + ), "different number of files in image_dir and reference_dir" + else: + path_references = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, "converting", True) + for idx, (path_image, path_reference) in enumerate( + zip(path_images, path_references) + ): + loop_info.update(idx) + + # convert image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + cmd = mri_convert + path_image + " " + path_result + " -odt float" + if interpolation is not None: + cmd += " -rt " + interpolation + if reference_dir is not None: + cmd += " -rl " + path_reference + if voxsize is not None: + voxsize = utils.reformat_to_list(voxsize, dtype="float") + cmd += " --voxsize " + " ".join([str(np.around(v, 3)) for v in voxsize]) + os.system(cmd) + + +def samseg_images_in_dir( + image_dir, + result_dir, + atlas_dir=None, + threads=4, + path_freesurfer="/usr/local/freesurfer", + keep_segm_only=True, + recompute=True, +): + """This function launches samseg for all images contained in image_dir and writes the results in result_dir. + If keep_segm_only=True, the result segmentation is copied in result_dir and SAMSEG's intermediate result dir is + deleted. + :param image_dir: path of directory with input images + :param result_dir: path of directory where processed images folders (if keep_segm_only is False), + or samseg segmentation (if keep_segm_only is True) will be writen + :param atlas_dir: (optional) path of samseg atlas directory. If None, use samseg default atlas. + :param threads: (optional) number of threads to use + :param path_freesurfer: (optional) path FreeSurfer home + :param keep_segm_only: (optional) whether to keep samseg result folders, or only samseg segmentations. + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # set up FreeSurfer + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + path_samseg = os.path.join(path_freesurfer, "bin", "run_samseg") + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # build path_result + path_im_result_dir = os.path.join( + result_dir, utils.strip_extension(os.path.basename(path_image)) + ) + path_samseg_result = os.path.join(path_im_result_dir, "seg.mgz") + if keep_segm_only: + path_result = os.path.join( + result_dir, + utils.strip_extension(os.path.basename(path_image)) + "_seg.mgz", + ) + else: + path_result = path_samseg_result + + # run samseg + if (not os.path.isfile(path_result)) | recompute: + cmd = utils.mkcmd( + path_samseg, + "-i", + path_image, + "-o", + path_im_result_dir, + "--threads", + threads, + ) + if atlas_dir is not None: + cmd = utils.mkcmd(cmd, "-a", atlas_dir) + os.system(cmd) + + # move segmentation to result_dir if necessary + if keep_segm_only: + if os.path.isfile(path_samseg_result): + shutil.move(path_samseg_result, path_result) + if os.path.isdir(path_im_result_dir): + shutil.rmtree(path_im_result_dir) + + +def niftyreg_images_in_dir( + image_dir, + reference_dir, + nifty_reg_function="reg_resample", + input_transformation_dir=None, + result_dir=None, + result_transformation_dir=None, + interpolation=None, + same_floating=False, + same_reference=False, + same_transformation=False, + path_nifty_reg="/home/benjamin/Softwares/niftyreg-gpu/build/reg-apps", + recompute=True, +): + """This function launches one of niftyreg functions (reg_aladin, reg_f3d, reg_resample) on all images contained + in image_dir. + :param image_dir: path of directory with images to register. Can also be a single image, in that case set + same_floating to True. + :param reference_dir: path of directory with reference images. If same_reference is false, references and images are + matched by sorting order. This can also be the path to a single image that will be used as reference for all images + im image_dir (set same_reference to True in that case). + :param nifty_reg_function: (optional) name of the niftyreg function to use. Can be 'reg_aladin', 'reg_f3d', or + 'reg_resample'. Default is 'reg_resample'. + :param input_transformation_dir: (optional) path of a directory containing all the input transformation (for + reg_resample, or reg_f3d). Can also be the path to a single transformation that will be used for all images + in image_dir (set same_transformation to True in that case). + :param result_dir: path of directory where output images will be writen. + :param result_transformation_dir: path of directory where resulting transformations will be writen (for + reg_aladin and reg_f3d). + :param interpolation: (optional) integer describing the order of the interpolation to apply (0 = nearest neighbours) + :param same_floating: (optional) set to true if only one image is used as floating image. + :param same_reference: (optional) whether to use a single image as reference for all input images. + :param same_transformation: (optional) whether to apply the same transformation to all floating images. + :param path_nifty_reg: (optional) path of the folder containing nifty-reg functions + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dirs + if result_dir is not None: + utils.mkdir(result_dir) + if result_transformation_dir is not None: + utils.mkdir(result_transformation_dir) + + nifty_reg = os.path.join(path_nifty_reg, nifty_reg_function) + + # list reference and floating images + path_images = utils.list_images_in_folder(image_dir) + path_references = utils.list_images_in_folder(reference_dir) + if same_reference: + path_references = utils.reformat_to_list( + path_references, length=len(path_images) + ) + if same_floating: + path_images = utils.reformat_to_list(path_images, length=len(path_references)) + assert len(path_references) == len( + path_images + ), "different number of files in image_dir and reference_dir" + + # list input transformations + if input_transformation_dir is not None: + if same_transformation: + path_input_transfs = utils.reformat_to_list( + input_transformation_dir, length=len(path_images) + ) + else: + path_input_transfs = utils.list_files(input_transformation_dir) + assert len(path_input_transfs) == len( + path_images + ), "different number of transformations and images" + else: + path_input_transfs = [None] * len(path_images) + + # define flag input trans + if input_transformation_dir is not None: + if nifty_reg_function == "reg_aladin": + flag_input_trans = "-inaff" + elif nifty_reg_function == "reg_f3d": + flag_input_trans = "-aff" + elif nifty_reg_function == "reg_resample": + flag_input_trans = "-trans" + else: + raise Exception( + 'nifty_reg_function can only be "reg_aladin", "reg_f3d", or "reg_resample"' + ) + else: + flag_input_trans = None + + # define flag result transformation + if result_transformation_dir is not None: + if nifty_reg_function == "reg_aladin": + flag_result_trans = "-aff" + elif nifty_reg_function == "reg_f3d": + flag_result_trans = "-cpp" + else: + raise Exception( + 'result_transformation_dir can only be used with "reg_aladin" or "reg_f3d"' + ) + else: + flag_result_trans = None + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) + for idx, (path_image, path_ref, path_input_trans) in enumerate( + zip(path_images, path_references, path_input_transfs) + ): + loop_info.update(idx) + + # define path registered image + name = ( + os.path.basename(path_ref) + if same_floating + else os.path.basename(path_image) + ) + if result_dir is not None: + path_result = os.path.join(result_dir, name) + result_already_computed = os.path.isfile(path_result) + else: + path_result = None + result_already_computed = True + + # define path resulting transformation + if result_transformation_dir is not None: + if nifty_reg_function == "reg_aladin": + path_result_trans = os.path.join( + result_transformation_dir, utils.strip_extension(name) + ".txt" + ) + result_trans_already_computed = os.path.isfile(path_result_trans) + else: + path_result_trans = os.path.join(result_transformation_dir, name) + result_trans_already_computed = os.path.isfile(path_result_trans) + else: + path_result_trans = None + result_trans_already_computed = True + + if ( + (not result_already_computed) + | (not result_trans_already_computed) + | recompute + ): + + # build main command + cmd = utils.mkcmd(nifty_reg, "-ref", path_ref, "-flo", path_image, "-pad 0") + + # add options + if path_result is not None: + cmd = utils.mkcmd(cmd, "-res", path_result) + if flag_input_trans is not None: + cmd = utils.mkcmd(cmd, flag_input_trans, path_input_trans) + if flag_result_trans is not None: + cmd = utils.mkcmd(cmd, flag_result_trans, path_result_trans) + if interpolation is not None: + cmd = utils.mkcmd(cmd, "-inter", interpolation) + + # execute + os.system(cmd) + + +def upsample_anisotropic_images( + image_dir, + resample_image_result_dir, + resample_like_dir, + path_freesurfer="/usr/local/freesurfer/", + recompute=True, +): + """This function takes as input a set of LR images and resample them to HR with respect to reference images. + :param image_dir: path of directory with input images (only uni-modal images supported) + :param resample_image_result_dir: path of directory where resampled images will be writen + :param resample_like_dir: path of directory with reference images. + :param path_freesurfer: (optional) path freesurfer home, as this function uses mri_convert + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(resample_image_result_dir) + + # set up FreeSurfer + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") + + # list images and labels + path_images = utils.list_images_in_folder(image_dir) + path_ref_images = utils.list_images_in_folder(resample_like_dir) + assert len(path_images) == len( + path_ref_images + ), "the folders containing the images and their references are not the same size" + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, "upsampling", True) + for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)): + loop_info.update(idx) + + # upsample image + _, _, n_dims, _, _, image_res = utils.get_volume_info( + path_image, return_volume=False + ) + path_im_upsampled = os.path.join( + resample_image_result_dir, os.path.basename(path_image) + ) + if (not os.path.isfile(path_im_upsampled)) | recompute: + cmd = utils.mkcmd( + mri_convert, + path_image, + path_im_upsampled, + "-rl", + path_ref, + "-odt float", + ) + os.system(cmd) + + path_dist_map = os.path.join( + resample_image_result_dir, "dist_map_" + os.path.basename(path_image) + ) + if (not os.path.isfile(path_dist_map)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing="ij") + tmp_dir = utils.strip_extension(path_im_upsampled) + "_meshes" + utils.mkdir(tmp_dir) + path_meshes_up = list() + for i, maps in enumerate(dist_map): + path_mesh = os.path.join( + tmp_dir, "%s_" % i + os.path.basename(path_image) + ) + path_mesh_up = os.path.join( + tmp_dir, "up_%s_" % i + os.path.basename(path_image) + ) + utils.save_volume(maps, aff, h, path_mesh) + cmd = utils.mkcmd( + mri_convert, + path_mesh, + path_mesh_up, + "-rl", + path_im_upsampled, + "-odt float", + ) + os.system(cmd) + path_meshes_up.append(path_mesh_up) + mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) + mesh_up = np.stack( + [mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1 + ) + shutil.rmtree(tmp_dir) + + floor = np.floor(mesh_up) + ceil = np.ceil(mesh_up) + f_dist = mesh_up - floor + c_dist = ceil - mesh_up + dist = np.minimum(f_dist, c_dist) * utils.add_axis( + image_res, axis=[0] * n_dims + ) + dist = np.sqrt(np.sum(dist**2, axis=-1)) + utils.save_volume(dist, aff, h, path_dist_map) + + +def simulate_upsampled_anisotropic_images( + image_dir, + downsample_image_result_dir, + resample_image_result_dir, + data_res, + labels_dir=None, + downsample_labels_result_dir=None, + slice_thickness=None, + build_dist_map=False, + path_freesurfer="/usr/local/freesurfer/", + gpu=True, + recompute=True, +): + """This function takes as input a set of HR images and creates two datasets with it: + 1) a set of LR images obtained by downsampling the HR images with nearest neighbour interpolation, + 2) a set of HR images obtained by resampling the LR images to native HR with linear interpolation. + Additionally, this function can also create a set of LR labels from label maps corresponding to the input images. + :param image_dir: path of directory with input images (only uni-model images supported) + :param downsample_image_result_dir: path of directory where downsampled images will be writen + :param resample_image_result_dir: path of directory where resampled images will be writen + :param data_res: resolution of LR images. Can either be: an int, a float, a list or a numpy array. + :param labels_dir: (optional) path of directory with label maps corresponding to input images + :param downsample_labels_result_dir: (optional) path of directory where downsampled label maps will be writen + :param slice_thickness: (optional) thickness of slices to simulate. Can be a number, a list or a numpy array. + :param build_dist_map: (optional) whether to return the resampled images with an additional channel indicating the + distance of each voxel to the nearest acquired voxel. Default is False. + :param path_freesurfer: (optional) path freesurfer home, as this function uses mri_convert + :param gpu: (optional) whether to use a fast gpu model for blurring + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(resample_image_result_dir) + utils.mkdir(downsample_image_result_dir) + if labels_dir is not None: + assert ( + downsample_labels_result_dir is not None + ), "downsample_labels_result_dir should not be None if labels_dir is specified" + utils.mkdir(downsample_labels_result_dir) + + # set up FreeSurfer + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") + + # list images and labels + path_images = utils.list_images_in_folder(image_dir) + path_labels = ( + [None] * len(path_images) + if labels_dir is None + else utils.list_images_in_folder(labels_dir) + ) + + # initialisation + _, _, n_dims, _, _, image_res = utils.get_volume_info( + path_images[0], return_volume=False, aff_ref=np.eye(4) + ) + data_res = np.squeeze( + utils.reformat_to_n_channels_array(data_res, n_dims, n_channels=1) + ) + slice_thickness = utils.reformat_to_list(slice_thickness, length=n_dims) + + # loop over images + previous_model_input_shape = None + model = None + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) + for idx, (path_image, path_labels) in enumerate(zip(path_images, path_labels)): + loop_info.update(idx) + + # downsample image + path_im_downsampled = os.path.join( + downsample_image_result_dir, os.path.basename(path_image) + ) + if (not os.path.isfile(path_im_downsampled)) | recompute: + im, _, aff, n_dims, _, h, image_res = utils.get_volume_info( + path_image, return_volume=True + ) + im, aff_aligned = align_volume_to_ref( + im, aff, aff_ref=np.eye(4), return_aff=True, n_dims=n_dims + ) + im_shape = list(im.shape[:n_dims]) + sigma = blurring_sigma_for_downsampling( + image_res, data_res, thickness=slice_thickness + ) + sigma = [ + 0 if data_res[i] == image_res[i] else sigma[i] for i in range(n_dims) + ] + + # blur image + if gpu: + if (im_shape != previous_model_input_shape) | (model is None): + previous_model_input_shape = im_shape + image_in = KL.Input(shape=im_shape + [1]) + image = GaussianBlur(sigma=sigma)(image_in) + model = Model(inputs=image_in, outputs=image) + im = np.squeeze(model.predict(utils.add_axis(im, axis=[0, -1]))) + else: + im = blur_volume(im, sigma, mask=None) + utils.save_volume(im, aff_aligned, h, path_im_downsampled) + + # downsample blurred image + voxsize = " ".join([str(r) for r in data_res]) + cmd = utils.mkcmd( + mri_convert, + path_im_downsampled, + path_im_downsampled, + "--voxsize", + voxsize, + "-odt float -rt nearest", + ) + os.system(cmd) + + # downsample labels if necessary + if path_labels is not None: + path_lab_downsampled = os.path.join( + downsample_labels_result_dir, os.path.basename(path_labels) + ) + if (not os.path.isfile(path_lab_downsampled)) | recompute: + cmd = utils.mkcmd( + mri_convert, + path_labels, + path_lab_downsampled, + "-rl", + path_im_downsampled, + "-odt float -rt nearest", + ) + os.system(cmd) + + # upsample image + path_im_upsampled = os.path.join( + resample_image_result_dir, os.path.basename(path_image) + ) + if (not os.path.isfile(path_im_upsampled)) | recompute: + cmd = utils.mkcmd( + mri_convert, + path_im_downsampled, + path_im_upsampled, + "-rl", + path_image, + "-odt float", + ) + os.system(cmd) + + if build_dist_map: + path_dist_map = os.path.join( + resample_image_result_dir, "dist_map_" + os.path.basename(path_image) + ) + if (not os.path.isfile(path_dist_map)) | recompute: + im, aff, h = utils.load_volume(path_im_downsampled, im_only=False) + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing="ij") + tmp_dir = utils.strip_extension(path_im_downsampled) + "_meshes" + utils.mkdir(tmp_dir) + path_meshes_up = list() + for i, d_map in enumerate(dist_map): + path_mesh = os.path.join( + tmp_dir, "%s_" % i + os.path.basename(path_image) + ) + path_mesh_up = os.path.join( + tmp_dir, "up_%s_" % i + os.path.basename(path_image) + ) + utils.save_volume(d_map, aff, h, path_mesh) + cmd = utils.mkcmd( + mri_convert, + path_mesh, + path_mesh_up, + "-rl", + path_image, + "-odt float", + ) + os.system(cmd) + path_meshes_up.append(path_mesh_up) + mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) + mesh_up = np.stack( + [mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1 + ) + shutil.rmtree(tmp_dir) + + floor = np.floor(mesh_up) + ceil = np.ceil(mesh_up) + f_dist = mesh_up - floor + c_dist = ceil - mesh_up + dist = np.minimum(f_dist, c_dist) * utils.add_axis( + data_res, axis=[0] * n_dims + ) + dist = np.sqrt(np.sum(dist**2, axis=-1)) + utils.save_volume(dist, aff, h, path_dist_map) + + +def check_images_in_dir( + image_dir, check_values=False, keep_unique=True, max_channels=10, verbose=True +): + """Check if all volumes within the same folder share the same characteristics: shape, affine matrix, resolution. + Also have option to check if all volumes have the same intensity values (useful for label maps). + :return four lists, each containing the different values detected for a specific parameter among those to check. + """ + + # define information to check + list_shape = list() + list_aff = list() + list_res = list() + list_axes = list() + if check_values: + list_unique_values = list() + else: + list_unique_values = None + + # loop through files + path_images = utils.list_images_in_folder(image_dir) + loop_info = ( + utils.LoopInfo(len(path_images), 10, "checking", verbose) if verbose else None + ) + for idx, path_image in enumerate(path_images): + if loop_info is not None: + loop_info.update(idx) + + # get info + im, shape, aff, n_dims, _, h, res = utils.get_volume_info( + path_image, True, np.eye(4), max_channels + ) + axes = get_ras_axes(aff, n_dims=n_dims).tolist() + aff[:, np.arange(n_dims)] = aff[:, axes] + aff = (np.int32(np.round(np.array(aff[:3, :3]), 2) * 100) / 100).tolist() + res = (np.int32(np.round(np.array(res), 2) * 100) / 100).tolist() + + # add values to list if not already there + if (shape not in list_shape) | (not keep_unique): + list_shape.append(shape) + if (aff not in list_aff) | (not keep_unique): + list_aff.append(aff) + if (res not in list_res) | (not keep_unique): + list_res.append(res) + if (axes not in list_axes) | (not keep_unique): + list_axes.append(axes) + if list_unique_values is not None: + uni = np.unique(im).tolist() + if (uni not in list_unique_values) | (not keep_unique): + list_unique_values.append(uni) + + return list_shape, list_aff, list_res, list_axes, list_unique_values + + +# ----------------------------------------------- edit label maps in dir ----------------------------------------------- + + +def correct_labels_in_dir( + labels_dir, + results_dir, + incorrect_labels, + correct_labels=None, + use_nearest_label=False, + remove_zero=False, + smooth=False, + recompute=True, +): + """This function corrects label values for all label maps in a folder with either + - a list a given values, + - or with the nearest label value. + :param labels_dir: path of directory with input label maps + :param results_dir: path of directory where corrected label maps will be writen + :param incorrect_labels: list of all label values to correct (e.g. [1, 2, 3, 4]). + :param correct_labels: (optional) list of correct label values to replace the incorrect ones. + Correct values must have the same order as their corresponding value in list_incorrect_labels. + When several correct values are possible for the same incorrect value, the nearest correct value will be selected at + each voxel to correct. In that case, the different correct values must be specified inside a list within + list_correct_labels (e.g. [10, 20, 30, [40, 50]). + :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels. + :param remove_zero: (optional) if use_nearest_label is True, set to True not to consider zero among the potential + candidates for the nearest neighbour. + :param smooth: (optional) whether to smooth the corrected label maps + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(results_dir) + + # prepare data files + path_labels = utils.list_images_in_folder(labels_dir) + loop_info = utils.LoopInfo(len(path_labels), 10, "correcting", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # correct labels + path_result = os.path.join(results_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_label, im_only=False, dtype="int32") + im = correct_label_map( + im, + incorrect_labels, + correct_labels, + use_nearest_label, + remove_zero, + smooth, + ) + utils.save_volume(im, aff, h, path_result) + + +def mask_labels_in_dir( + labels_dir, + result_dir, + values_to_keep, + masking_value=0, + mask_result_dir=None, + recompute=True, +): + """This function masks all label maps in a folder by keeping a set of given label values. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where corrected label maps will be writen + :param values_to_keep: list of values for masking the label maps. + :param masking_value: (optional) value to mask the label maps with + :param mask_result_dir: (optional) path of directory where applied masks will be writen + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + if mask_result_dir is not None: + utils.mkdir(mask_result_dir) + + # reformat values to keep + values_to_keep = utils.reformat_to_list(values_to_keep, load_as_numpy=True) + + # loop over labels + path_labels = utils.list_images_in_folder(labels_dir) + loop_info = utils.LoopInfo(len(path_labels), 10, "masking", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # mask labels + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if mask_result_dir is not None: + path_result_mask = os.path.join( + mask_result_dir, os.path.basename(path_label) + ) + else: + path_result_mask = "" + if ( + (not os.path.isfile(path_result)) + | (mask_result_dir is not None) & (not os.path.isfile(path_result_mask)) + | recompute + ): + lab, aff, h = utils.load_volume(path_label, im_only=False) + if mask_result_dir is not None: + labels, mask = mask_label_map( + lab, values_to_keep, masking_value, return_mask=True + ) + path_result_mask = os.path.join( + mask_result_dir, os.path.basename(path_label) + ) + utils.save_volume(mask, aff, h, path_result_mask) + else: + labels = mask_label_map( + lab, values_to_keep, masking_value, return_mask=False + ) + utils.save_volume(labels, aff, h, path_result) + + +def smooth_labels_in_dir( + labels_dir, result_dir, gpu=False, labels_list=None, connectivity=1, recompute=True +): + """Smooth all label maps in a folder by replacing each voxel by the value of its most numerous neighbours. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where smoothed label maps will be writen + :param gpu: (optional) whether to use a gpu implementation for faster processing + :param labels_list: (optional) if gpu is True, path of numpy array with all label values. + Automatically computed if not provided. + :param connectivity: (optional) connectivity to use when smoothing the label maps + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list label maps + path_labels = utils.list_images_in_folder(labels_dir) + + if labels_list is not None: + labels_list, _ = utils.get_list_labels(label_list=labels_list, FS_sort=True) + + if gpu: + # initialisation + previous_model_input_shape = None + smoothing_model = None + + # loop over label maps + loop_info = utils.LoopInfo(len(path_labels), 10, "smoothing", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # smooth label map + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + labels, label_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_label, return_volume=True + ) + if label_shape != previous_model_input_shape: + previous_model_input_shape = label_shape + smoothing_model = smoothing_gpu_model( + label_shape, labels_list, connectivity + ) + unique_labels = np.unique(labels).astype("int32") + if labels_list is None: + smoothed_labels = smoothing_model.predict(utils.add_axis(labels)) + else: + labels_to_keep = [ + lab for lab in unique_labels if lab not in labels_list + ] + new_labels, mask_new_labels = mask_label_map( + labels, labels_to_keep, return_mask=True + ) + smoothed_labels = np.squeeze( + smoothing_model.predict(utils.add_axis(labels)) + ) + smoothed_labels = np.where( + mask_new_labels, new_labels, smoothed_labels + ) + mask_new_zeros = (labels > 0) & (smoothed_labels == 0) + smoothed_labels[mask_new_zeros] = labels[mask_new_zeros] + utils.save_volume(smoothed_labels, aff, h, path_result, dtype="int32") + + else: + # build kernel + _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) + kernel = utils.build_binary_structure(connectivity, n_dims, shape=n_dims) + + # loop over label maps + loop_info = utils.LoopInfo(len(path_labels), 10, "smoothing", True) + for idx, path in enumerate(path_labels): + loop_info.update(idx) + + # smooth label map + path_result = os.path.join(result_dir, os.path.basename(path)) + if (not os.path.isfile(path_result)) | recompute: + volume, aff, h = utils.load_volume(path, im_only=False) + new_volume = smooth_label_map(volume, kernel, labels_list) + utils.save_volume(new_volume, aff, h, path_result, dtype="int32") + + +def smoothing_gpu_model(label_shape, label_list, connectivity=1): + """This function builds a gpu model in keras with a tensorflow backend to smooth label maps. + This model replaces each voxel of the input by the value of its most numerous neighbour. + :param label_shape: shape of the label map + :param label_list: list of all labels to consider + :param connectivity: (optional) connectivity to use when smoothing the label maps + :return: gpu smoothing model + """ + + # convert labels so values are in [0, ..., N-1] and use one hot encoding + n_labels = label_list.shape[0] + labels_in = KL.Input(shape=label_shape, name="lab_input", dtype="int32") + labels = ConvertLabels(label_list)(labels_in) + labels = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, dtype="int32"), depth=n_labels, axis=-1) + )(labels) + + # count neighbouring voxels + n_dims, _ = utils.get_dims(label_shape) + k = utils.add_axis( + utils.build_binary_structure(connectivity, n_dims, shape=n_dims), axis=[-1, -1] + ) + kernel = KL.Lambda(lambda x: tf.convert_to_tensor(k, dtype="float32"))([]) + split = KL.Lambda(lambda x: tf.split(x, [1] * n_labels, axis=-1))(labels) + labels = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding="SAME"))( + [split[0], kernel] + ) + for i in range(1, n_labels): + tmp = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding="SAME"))( + [split[i], kernel] + ) + labels = KL.Lambda(lambda x: tf.concat([x[0], x[1]], -1))([labels, tmp]) + + # take the argmax and convert labels to original values + labels = KL.Lambda(lambda x: tf.math.argmax(x, -1))(labels) + labels = ConvertLabels(np.arange(n_labels), label_list)(labels) + return Model(inputs=labels_in, outputs=labels) + + +def erode_labels_in_dir( + labels_dir, + result_dir, + labels_to_erode, + erosion_factors=1.0, + gpu=False, + recompute=True, +): + """Erode a given set of label values for all label maps in a folder. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where cropped label maps will be writen + :param labels_to_erode: list of label values to erode + :param erosion_factors: (optional) list of erosion factors to use for each label value. If values are integers, + normal erosion applies. If float, we first 1) blur a mask of the corresponding label value with a gpu model, + and 2) use the erosion factor as a threshold in the blurred mask. + If erosion_factors is a single value, the same factor will be applied to all labels. + :param gpu: (optional) whether to use a fast gpu model for blurring (if erosion factors are floats) + :param recompute: (optional) whether to recompute result files even if they already exists + """ + # create result dir + utils.mkdir(result_dir) + + # loop over label maps + model = None + path_labels = utils.list_images_in_folder(labels_dir) + loop_info = utils.LoopInfo(len(path_labels), 5, "eroding", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # erode label map + labels, aff, h = utils.load_volume(path_label, im_only=False) + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + labels, model = erode_label_map( + labels, labels_to_erode, erosion_factors, gpu, model, return_model=True + ) + utils.save_volume(labels, aff, h, path_result) + + +def upsample_labels_in_dir( + labels_dir, + target_res, + result_dir, + path_label_list=None, + path_freesurfer="/usr/local/freesurfer/", + recompute=True, +): + """This function upsamples all label maps within a folder. Importantly, each label map is converted into probability + maps for all label values, and all these maps are upsampled separately. The upsampled label maps are recovered by + taking the argmax of the label values probability maps. + :param labels_dir: path of directory with label maps to upsample + :param target_res: resolution at which to upsample the label maps. can be a single number (isotropic), or a list. + :param result_dir: path of directory where the upsampled label maps will be writen + :param path_label_list: (optional) path of numpy array containing all label values. + Computed automatically if not given. + :param path_freesurfer: (optional) path freesurfer home (upsampling performed with mri_convert) + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # prepare result dir + utils.mkdir(result_dir) + + # set up FreeSurfer + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") + + # list label maps + path_labels = utils.list_images_in_folder(labels_dir) + labels_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_labels[0], max_channels=3 + ) + + # build command + target_res = utils.reformat_to_list(target_res, length=n_dims) + post_cmd = "-voxsize " + " ".join([str(r) for r in target_res]) + " -odt float" + + # load label list and corresponding LUT to make sure that labels go from 0 to N-1 + label_list, _ = utils.get_list_labels( + path_label_list, labels_dir=path_labels, FS_sort=False + ) + new_label_list = np.arange(len(label_list), dtype="int32") + lut = utils.get_mapping_lut(label_list) + + # loop over label maps + loop_info = utils.LoopInfo(len(path_labels), 5, "upsampling", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + + # load volume + labels, aff, h = utils.load_volume(path_label, im_only=False) + labels = lut[labels.astype("int")] + + # create individual folders for label map + basefilename = utils.strip_extension(os.path.basename(path_label)) + indiv_label_dir = os.path.join(result_dir, basefilename) + upsample_indiv_label_dir = os.path.join( + result_dir, basefilename + "_upsampled" + ) + utils.mkdir(indiv_label_dir) + utils.mkdir(upsample_indiv_label_dir) + + # loop over label values + for label in new_label_list: + path_mask = os.path.join(indiv_label_dir, str(label) + ".nii.gz") + path_mask_upsampled = os.path.join( + upsample_indiv_label_dir, str(label) + ".nii.gz" + ) + if not os.path.isfile(path_mask): + mask = (labels == label) * 1.0 + utils.save_volume(mask, aff, h, path_mask) + if not os.path.isfile(path_mask_upsampled): + cmd = utils.mkcmd( + mri_convert, path_mask, path_mask_upsampled, post_cmd + ) + os.system(cmd) + + # compute argmax of upsampled probability maps (upload them one at a time) + probmax, aff, h = utils.load_volume( + os.path.join(upsample_indiv_label_dir, "0.nii.gz"), im_only=False + ) + labels = np.zeros(probmax.shape, dtype="int") + for label in new_label_list: + prob = utils.load_volume( + os.path.join(upsample_indiv_label_dir, str(label) + ".nii.gz") + ) + idx = prob > probmax + labels[idx] = label + probmax[idx] = prob[idx] + utils.save_volume(label_list[labels], aff, h, path_result, dtype="int32") + + +def compute_hard_volumes_in_dir( + labels_dir, + voxel_volume=None, + path_label_list=None, + skip_background=True, + path_numpy_result=None, + path_csv_result=None, + FS_sort=False, +): + """Compute hard volumes of structures for all label maps in a folder. + :param labels_dir: path of directory with input label maps + :param voxel_volume: (optional) volume of the voxels. If None, it will be directly inferred from the file header. + Set to 1 for a voxel count. + :param path_label_list: (optional) list of labels to compute volumes for. + Can be an int, a sequence, or a numpy array. If None, the volumes of all label values are computed for each subject. + :param skip_background: (optional) whether to skip computing the volume of the background. + If label_list is None, this assumes background value is 0. + If label_list is not None, this assumes the background is the first value in label list. + :param path_numpy_result: (optional) path where to write the result volumes as a numpy array. + :param path_csv_result: (optional) path where to write the results as csv file. + :param FS_sort: (optional) whether to sort the labels in FreeSurfer order. + :return: numpy array with the volume of each structure for all subjects. + Rows represent label values, and columns represent subjects. + """ + + # create result directories + if path_numpy_result is not None: + utils.mkdir(os.path.dirname(path_numpy_result)) + if path_csv_result is not None: + utils.mkdir(os.path.dirname(path_csv_result)) + + # load or compute labels list + label_list, _ = utils.get_list_labels(path_label_list, labels_dir, FS_sort=FS_sort) + + # create csv volume file if necessary + if path_csv_result is not None: + if skip_background: + cvs_header = [["subject"] + [str(lab) for lab in label_list[1:]]] + else: + cvs_header = [["subject"] + [str(lab) for lab in label_list]] + with open(path_csv_result, "w") as csvFile: + writer = csv.writer(csvFile) + writer.writerows(cvs_header) + csvFile.close() + + # loop over label maps + path_labels = utils.list_images_in_folder(labels_dir) + if skip_background: + volumes = np.zeros((label_list.shape[0] - 1, len(path_labels))) + else: + volumes = np.zeros((label_list.shape[0], len(path_labels))) + loop_info = utils.LoopInfo(len(path_labels), 10, "processing", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # load segmentation, and compute unique labels + labels, _, _, _, _, _, subject_res = utils.get_volume_info( + path_label, return_volume=True + ) + if voxel_volume is None: + voxel_volume = float(np.prod(subject_res)) + subject_volumes = compute_hard_volumes( + labels, voxel_volume, label_list, skip_background + ) + volumes[:, idx] = subject_volumes + + # write volumes + if path_csv_result is not None: + subject_volumes = np.around(volumes[:, idx], 3) + row = [utils.strip_suffix(os.path.basename(path_label))] + [ + str(vol) for vol in subject_volumes + ] + with open(path_csv_result, "a") as csvFile: + writer = csv.writer(csvFile) + writer.writerow(row) + csvFile.close() + + # write numpy array if necessary + if path_numpy_result is not None: + np.save(path_numpy_result, volumes) + + return volumes + + +def build_atlas( + labels_dir, + label_list, + align_centre_of_mass=False, + margin=15, + shape=None, + path_atlas=None, +): + """This function builds a binary atlas (defined by label values > 0) from several label maps. + :param labels_dir: path of directory with input label maps + :param label_list: list of all labels in the label maps. If there is more than 1 value here, the different channels + of the atlas (each corresponding to the probability map of a given label) will in the same order as in this list. + :param align_centre_of_mass: whether to build the atlas by aligning the center of mass of each label map. + If False, the atlas has the same size as the input label maps, which are assumed to be aligned. + :param margin: (optional) If align_centre_of_mass is True, margin by which to crop the input label maps around + their center of mass. Therefore it controls the size of the output atlas: (2*margin + 1)**n_dims. + :param shape: shape of the output atlas. + :param path_atlas: (optional) path where the output atlas will be writen. + Default is None, where the atlas is not saved.""" + + # list of all label maps and create result dir + path_labels = utils.list_images_in_folder(labels_dir) + n_label_maps = len(path_labels) + utils.mkdir(os.path.dirname(path_atlas)) + + # read list labels and create lut + label_list = np.array( + utils.reformat_to_list(label_list, load_as_numpy=True, dtype="int") + ) + lut = utils.get_mapping_lut(label_list) + n_labels = len(label_list) + + # create empty atlas + im_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_labels[0], aff_ref=np.eye(4) + ) + if align_centre_of_mass: + shape = [margin * 2] * n_dims + else: + shape = ( + utils.reformat_to_list(shape, length=n_dims) + if shape is not None + else im_shape + ) + shape = shape + [n_labels] if n_labels > 1 else shape + atlas = np.zeros(shape) + + # loop over label maps + loop_info = utils.LoopInfo(n_label_maps, 10, "processing", True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # load label map and build mask + lab = utils.load_volume(path_label, dtype="int32", aff_ref=np.eye(4)) + lab = correct_label_map(lab, [31, 63, 72], [4, 43, 0]) + lab = lut[lab.astype("int")] + lab = pad_volume(lab, shape[:n_dims]) + lab = crop_volume(lab, cropping_shape=shape[:n_dims]) + indices = np.where(lab > 0) + + if len(label_list) > 1: + lab = np.identity(n_labels)[lab] + + # crop label map around centre of mass + if align_centre_of_mass: + centre_of_mass = np.array( + [np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])], + dtype="int32", + ) + min_crop = centre_of_mass - margin + max_crop = centre_of_mass + margin + atlas += lab[ + min_crop[0] : max_crop[0], + min_crop[1] : max_crop[1], + min_crop[2] : max_crop[2], + ..., + ] + # otherwise just add the one-hot labels + else: + atlas += lab + + # normalise atlas and save it if necessary + atlas /= n_label_maps + atlas = align_volume_to_ref(atlas, np.eye(4), aff_ref=aff, n_dims=n_dims) + if path_atlas is not None: + utils.save_volume(atlas, aff, h, path_atlas) + + return atlas + + +# ---------------------------------------------------- edit dataset ---------------------------------------------------- + + +def check_images_and_labels(image_dir, labels_dir, verbose=True): + """Check if corresponding images and labels have the same affine matrices and shapes. + Labels are matched to images by sorting order. + :param image_dir: path of directory with input images + :param labels_dir: path of directory with corresponding label maps + :param verbose: whether to print out info + """ + + # list images and labels + path_images = utils.list_images_in_folder(image_dir) + path_labels = utils.list_images_in_folder(labels_dir) + assert len(path_images) == len( + path_labels + ), "different number of files in image_dir and labels_dir" + + # loop over images and labels + loop_info = ( + utils.LoopInfo(len(path_images), 10, "checking", verbose) if verbose else None + ) + for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): + if loop_info is not None: + loop_info.update(idx) + + # load images and labels + im, aff_im, h_im = utils.load_volume(path_image, im_only=False) + lab, aff_lab, h_lab = utils.load_volume(path_label, im_only=False) + aff_im_list = np.round(aff_im, 2).tolist() + aff_lab_list = np.round(aff_lab, 2).tolist() + + # check matching affine and shape + if aff_lab_list != aff_im_list: + print("aff mismatch :\n" + path_image) + print(aff_im_list) + print(path_label) + print(aff_lab_list) + print("") + if lab.shape != im.shape: + print("shape mismatch :\n" + path_image) + print(im.shape) + print("\n" + path_label) + print(lab.shape) + print("") + + +def crop_dataset_to_minimum_size( + labels_dir, result_dir, image_dir=None, image_result_dir=None, margin=5 +): + """Crop all label maps in a directory to the minimum possible common size, with a margin. + This is achieved by cropping each label map individually to the minimum size, and by padding all the cropped maps to + the same size (taken to be the maximum size of the cropped maps). + If images are provided, they undergo the same transformations as their corresponding label maps. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where cropped label maps will be writen + :param image_dir: (optional) if not None, the cropping will be applied to all images in this directory + :param image_result_dir: (optional) path of directory where cropped images will be writen + :param margin: (optional) margin to apply around the label maps during cropping + """ + + # create result dir + utils.mkdir(result_dir) + if image_dir is not None: + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" + utils.mkdir(image_result_dir) + + # list labels and images + path_labels = utils.list_images_in_folder(labels_dir) + if image_dir is not None: + path_images = utils.list_images_in_folder(image_dir) + else: + path_images = [None] * len(path_labels) + _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) + + # loop over label maps for cropping + print("\ncropping labels to individual minimum size") + maximum_size = np.zeros(n_dims) + loop_info = utils.LoopInfo(len(path_labels), 10, "cropping", True) + for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): + loop_info.update(idx) + + # crop label maps and update maximum size of cropped map + label, aff, h = utils.load_volume(path_label, im_only=False) + label, cropping, aff = crop_volume_around_region(label, aff=aff) + utils.save_volume( + label, aff, h, os.path.join(result_dir, os.path.basename(path_label)) + ) + maximum_size = np.maximum( + maximum_size, np.array(label.shape) + margin * 2 + ) # *2 to add margin on each side + + # crop images if required + if path_image is not None: + image, aff_im, h_im = utils.load_volume(path_image, im_only=False) + image, aff_im = crop_volume_with_idx(image, cropping, aff=aff_im) + utils.save_volume( + image, + aff_im, + h_im, + os.path.join(image_result_dir, os.path.basename(path_image)), + ) + + # loop over label maps for padding + print("\npadding labels to same size") + loop_info = utils.LoopInfo(len(path_labels), 10, "padding", True) + for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): + loop_info.update(idx) + + # pad label maps to maximum size + path_result = os.path.join(result_dir, os.path.basename(path_label)) + label, aff, h = utils.load_volume(path_result, im_only=False) + label, aff = pad_volume(label, maximum_size, aff=aff) + utils.save_volume(label, aff, h, path_result) + + # crop images if required + if path_image is not None: + path_result = os.path.join(image_result_dir, os.path.basename(path_image)) + image, aff, h = utils.load_volume(path_result, im_only=False) + image, aff = pad_volume(image, maximum_size, aff=aff) + utils.save_volume(image, aff, h, path_result) + + +def crop_dataset_around_region_of_same_size( + labels_dir, + result_dir, + image_dir=None, + image_result_dir=None, + margin=0, + recompute=True, +): + + # create result dir + utils.mkdir(result_dir) + if image_dir is not None: + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" + utils.mkdir(image_result_dir) + + # list labels and images + path_labels = utils.list_images_in_folder(labels_dir) + path_images = ( + utils.list_images_in_folder(image_dir) + if image_dir is not None + else [None] * len(path_labels) + ) + _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) + + recompute_labels = any( + [ + not os.path.isfile(os.path.join(result_dir, os.path.basename(path))) + for path in path_labels + ] + ) + if (image_dir is not None) & (not recompute_labels): + recompute_labels = any( + [ + not os.path.isfile( + os.path.join(image_result_dir, os.path.basename(path)) + ) + for path in path_images + ] + ) + + # get minimum patch shape so that no labels are left out when doing the cropping later on + max_crop_shape = np.zeros(n_dims) + if recompute_labels: + for path_label in path_labels: + label, aff, _ = utils.load_volume(path_label, im_only=False) + label = align_volume_to_ref(label, aff, aff_ref=np.eye(4)) + label = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) + _, cropping = crop_volume_around_region(label) + max_crop_shape = np.maximum( + cropping[n_dims:] - cropping[:n_dims], max_crop_shape + ) + max_crop_shape += np.array( + utils.reformat_to_list(margin, length=n_dims, dtype="int") + ) + print("max_crop_shape: ", max_crop_shape) + + # crop shapes (possibly with padding if images are smaller than crop shape) + for path_label, path_image in zip(path_labels, path_images): + + path_label_result = os.path.join(result_dir, os.path.basename(path_label)) + path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) + + if ( + (not os.path.isfile(path_image_result)) + | (not os.path.isfile(path_label_result)) + | recompute + ): + # load labels + label, aff, h_la = utils.load_volume( + path_label, im_only=False, dtype="int32" + ) + label, aff_new = align_volume_to_ref( + label, aff, aff_ref=np.eye(4), return_aff=True + ) + vol_shape = np.array(label.shape[:n_dims]) + if path_image is not None: + image, _, h_im = utils.load_volume(path_image, im_only=False) + image = align_volume_to_ref(image, aff, aff_ref=np.eye(4)) + else: + image = h_im = None + + # mask labels + mask = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) + label[np.logical_not(mask)] = 0 + + # find cropping indices + indices = np.nonzero(mask) + min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) + max_idx = np.minimum( + np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape + ) + + # expand/retract (depending on the desired shape) the cropping region around the centre + intermediate_vol_shape = max_idx - min_idx + min_idx = min_idx - np.int32( + np.ceil((max_crop_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((max_crop_shape - intermediate_vol_shape) / 2) + ) + + # check if we need to pad the output to the desired shape + min_padding = np.abs(np.minimum(min_idx, 0)) + max_padding = np.maximum(max_idx - vol_shape, 0) + if np.any(min_padding > 0) | np.any(max_padding > 0): + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) + else: + pad_margins = None + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) + + # crop volume + label = crop_volume_with_idx(label, cropping, n_dims=n_dims) + if path_image is not None: + image = crop_volume_with_idx(image, cropping, n_dims=n_dims) + + # pad volume if necessary + if pad_margins is not None: + label = np.pad(label, pad_margins, mode="constant", constant_values=0) + if path_image is not None: + _, n_channels = utils.get_dims(image.shape) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) + if n_channels > 1 + else pad_margins + ) + image = np.pad( + image, pad_margins, mode="constant", constant_values=0 + ) + + # update aff + if n_dims == 2: + min_idx = np.append(min_idx, 0) + aff_new[0:3, -1] = aff_new[0:3, -1] + aff_new[:3, :3] @ min_idx + + # write labels + label, aff_final = align_volume_to_ref( + label, aff_new, aff_ref=aff, return_aff=True + ) + utils.save_volume(label, aff_final, h_la, path_label_result, dtype="int32") + if path_image is not None: + image = align_volume_to_ref(image, aff_new, aff_ref=aff) + utils.save_volume(image, aff_final, h_im, path_image_result) + + +def crop_dataset_around_region( + image_dir, + labels_dir, + image_result_dir, + labels_result_dir, + margin=0, + cropping_shape_div_by=None, + recompute=True, +): + + # create result dir + utils.mkdir(image_result_dir) + utils.mkdir(labels_result_dir) + + # list volumes and masks + path_images = utils.list_images_in_folder(image_dir) + path_labels = utils.list_images_in_folder(labels_dir) + _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_labels[0]) + + # loop over images and labels + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) + for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): + loop_info.update(idx) + + path_label_result = os.path.join( + labels_result_dir, os.path.basename(path_label) + ) + path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) + + if ( + (not os.path.isfile(path_label_result)) + | (not os.path.isfile(path_image_result)) + | recompute + ): + + image, aff, h_im = utils.load_volume(path_image, im_only=False) + label, _, h_lab = utils.load_volume(path_label, im_only=False) + mask = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) + label[np.logical_not(mask)] = 0 + vol_shape = np.array(label.shape[:n_dims]) + + # find cropping indices + indices = np.nonzero(mask) + min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) + max_idx = np.minimum( + np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape + ) + + # expand/retract (depending on the desired shape) the cropping region around the centre + intermediate_vol_shape = max_idx - min_idx + cropping_shape = np.array( + [ + utils.find_closest_number_divisible_by_m( + s, cropping_shape_div_by, answer_type="higher" + ) + for s in intermediate_vol_shape + ] + ) + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) + + # check if we need to pad the output to the desired shape + min_padding = np.abs(np.minimum(min_idx, 0)) + max_padding = np.maximum(max_idx - vol_shape, 0) + if np.any(min_padding > 0) | np.any(max_padding > 0): + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) + else: + pad_margins = None + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) + + # crop volume + label = crop_volume_with_idx(label, cropping, n_dims=n_dims) + image = crop_volume_with_idx(image, cropping, n_dims=n_dims) + + # pad volume if necessary + if pad_margins is not None: + label = np.pad(label, pad_margins, mode="constant", constant_values=0) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) + if n_channels > 1 + else pad_margins + ) + image = np.pad(image, pad_margins, mode="constant", constant_values=0) + + # update aff + if n_dims == 2: + min_idx = np.append(min_idx, 0) + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ min_idx + + # write results + utils.save_volume(image, aff, h_im, path_image_result) + utils.save_volume(label, aff, h_lab, path_label_result, dtype="int32") + + +def subdivide_dataset_to_patches( + patch_shape, + image_dir=None, + image_result_dir=None, + labels_dir=None, + labels_result_dir=None, + full_background=True, + remove_after_dividing=False, +): + """This function subdivides images and/or label maps into several smaller patches of specified shape. + :param patch_shape: shape of patches to create. Can either be an int, a sequence, or a 1d numpy array. + :param image_dir: (optional) path of directory with input images + :param image_result_dir: (optional) path of directory where image patches will be writen + :param labels_dir: (optional) path of directory with input label maps + :param labels_result_dir: (optional) path of directory where label map patches will be writen + :param full_background: (optional) whether to keep patches only labelled as background (only if label maps are + provided). + :param remove_after_dividing: (optional) whether to delete input images after having divided them in smaller + patches. This enables to save disk space in the subdivision process. + """ + + # create result dir and list images and label maps + assert (image_dir is not None) | ( + labels_dir is not None + ), "at least one of image_dir or labels_dir should not be None." + if image_dir is not None: + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" + utils.mkdir(image_result_dir) + path_images = utils.list_images_in_folder(image_dir) + else: + path_images = None + if labels_dir is not None: + assert ( + labels_result_dir is not None + ), "labels_result_dir should not be None if labels_dir is specified" + utils.mkdir(labels_result_dir) + path_labels = utils.list_images_in_folder(labels_dir) + else: + path_labels = None + if path_images is None: + path_images = [None] * len(path_labels) + if path_labels is None: + path_labels = [None] * len(path_images) + + # reformat path_shape + patch_shape = utils.reformat_to_list(patch_shape) + n_dims, _ = utils.get_dims(patch_shape) + + # loop over images and labels + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) + for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): + loop_info.update(idx) + + # load image and labels + if path_image is not None: + im, aff_im, h_im = utils.load_volume( + path_image, im_only=False, squeeze=False + ) + else: + im = aff_im = h_im = None + if path_label is not None: + lab, aff_lab, h_lab = utils.load_volume( + path_label, im_only=False, squeeze=True + ) + else: + lab = aff_lab = h_lab = None + + # get volume shape + if path_image is not None: + shape = im.shape + else: + shape = lab.shape + + # crop image and label map to size divisible by patch_shape + new_size = np.array( + [ + utils.find_closest_number_divisible_by_m(shape[i], patch_shape[i]) + for i in range(n_dims) + ] + ) + crop = np.round((np.array(shape[:n_dims]) - new_size) / 2).astype("int") + crop = np.concatenate((crop, crop + new_size), axis=0) + if (im is not None) & (n_dims == 2): + im = im[crop[0] : crop[2], crop[1] : crop[3], ...] + elif (im is not None) & (n_dims == 3): + im = im[crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ...] + if (lab is not None) & (n_dims == 2): + lab = lab[crop[0] : crop[2], crop[1] : crop[3], ...] + elif (lab is not None) & (n_dims == 3): + lab = lab[crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ...] + + # loop over patches + n_im = 0 + n_crop = (new_size / patch_shape).astype("int") + for i in range(n_crop[0]): + i *= patch_shape[0] + for j in range(n_crop[1]): + j *= patch_shape[1] + + if n_dims == 2: + + # crop volumes + if lab is not None: + temp_la = lab[ + i : i + patch_shape[0], j : j + patch_shape[1], ... + ] + else: + temp_la = None + if im is not None: + temp_im = im[ + i : i + patch_shape[0], j : j + patch_shape[1], ... + ] + else: + temp_im = None + + # write patches + if temp_la is not None: + if full_background | (not (temp_la == 0).all()): + n_im += 1 + utils.save_volume( + temp_la, + aff_lab, + h_lab, + os.path.join( + labels_result_dir, + os.path.basename( + path_label.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) + if temp_im is not None: + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) + else: + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace(".nii.gz", "_%d.nii.gz" % n_im) + ), + ), + ) + + elif n_dims == 3: + for k in range(n_crop[2]): + k *= patch_shape[2] + + # crop volumes + if lab is not None: + temp_la = lab[ + i : i + patch_shape[0], + j : j + patch_shape[1], + k : k + patch_shape[2], + ..., + ] + else: + temp_la = None + if im is not None: + temp_im = im[ + i : i + patch_shape[0], + j : j + patch_shape[1], + k : k + patch_shape[2], + ..., + ] + else: + temp_im = None + + # write patches + if temp_la is not None: + if full_background | (not (temp_la == 0).all()): + n_im += 1 + utils.save_volume( + temp_la, + aff_lab, + h_lab, + os.path.join( + labels_result_dir, + os.path.basename( + path_label.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) + if temp_im is not None: + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) + else: + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) + + if remove_after_dividing: + if path_image is not None: + os.remove(path_image) + if path_label is not None: + os.remove(path_label) diff --git a/nobrainer/ext/lab2im/image_generator.py b/nobrainer/ext/lab2im/image_generator.py new file mode 100644 index 00000000..d8f83bc0 --- /dev/null +++ b/nobrainer/ext/lab2im/image_generator.py @@ -0,0 +1,309 @@ +""" +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +# python imports +import numpy as np +import numpy.random as npr + +# project imports +from nobrainer.ext.lab2im import edit_volumes, utils +from nobrainer.ext.lab2im.lab2im_model import lab2im_model + + +class ImageGenerator: + + def __init__( + self, + labels_dir, + generation_labels=None, + output_labels=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + blur_range=1.15, + ): + """ + This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels + maps, and a python generator that supplies the input data for this model. + To generate pairs of image/labels you can just call the method generate_image() on an object of this class. + + :param labels_dir: path of folder with all input label maps, or to a single label map. + + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + + # label maps-related parameters + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array. + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in + the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps + will be converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + + # output-related parameters + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthetised. Default is 1. + :param target_res: (optional) target resolution of the generated images and corresponding label maps. + If None, the outputs will have the same resolution as the input label maps. + Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image. + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites + output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or + the path to a 1d numpy array. + + # GMM-sampling parameters + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a + 1d numpy array, or the path to a 1d numpy array. + It should have the same length as generation_labels, and contain values between 0 and K-1, where K is the total + number of classes. Default is all labels have different classes (K=len(generation_labels)). + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each + mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + + # blurring parameters + :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is + given or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a c + coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range]. + If None, no randomisation. Default is 1.15. + """ + + # prepare data files + self.labels_paths = utils.list_images_in_folder(labels_dir) + + # generation parameters + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = ( + utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + ) + self.n_channels = n_channels + if generation_labels is not None: + self.generation_labels = utils.load_array_if_path(generation_labels) + else: + self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir) + if output_labels is not None: + self.output_labels = utils.load_array_if_path(output_labels) + else: + self.output_labels = self.generation_labels + self.target_res = utils.load_array_if_path(target_res) + self.batchsize = batchsize + # preliminary operations + self.output_shape = utils.load_array_if_path(output_shape) + self.output_div_by_n = output_div_by_n + # GMM parameters + self.prior_distributions = prior_distributions + if generation_classes is not None: + self.generation_classes = utils.load_array_if_path(generation_classes) + assert ( + self.generation_classes.shape == self.generation_labels.shape + ), "if provided, generation labels should have the same shape as generation_labels" + unique_classes = np.unique(self.generation_classes) + assert np.array_equal( + unique_classes, np.arange(np.max(unique_classes) + 1) + ), "generation_classes should a linear range between 0 and its maximum value." + else: + self.generation_classes = np.arange(self.generation_labels.shape[0]) + self.prior_means = utils.load_array_if_path(prior_means) + self.prior_stds = utils.load_array_if_path(prior_stds) + self.use_specific_stats_for_channel = use_specific_stats_for_channel + + # blurring parameters + self.blur_range = blur_range + + # build transformation model + self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model() + + # build generator for model inputs + self.model_inputs_generator = self._build_model_inputs( + len(self.generation_labels) + ) + + # build brain generator + self.image_generator = self._build_image_generator() + + def _build_lab2im_model(self): + # build_model + lab_to_im_model = lab2im_model( + labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + blur_range=self.blur_range, + ) + out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] + return lab_to_im_model, out_shape + + def _build_image_generator(self): + while True: + model_inputs = next(self.model_inputs_generator) + [image, labels] = self.labels_to_image_model.predict(model_inputs) + yield image, labels + + def generate_image(self): + """call this method when an object of this class has been instantiated to generate new brains""" + (image, labels) = next(self.image_generator) + # put back images in native space + list_images = list() + list_labels = list() + for i in range(self.batchsize): + list_images.append( + edit_volumes.align_volume_to_ref( + image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + list_labels.append( + edit_volumes.align_volume_to_ref( + labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + image = np.stack(list_images, axis=0) + labels = np.stack(list_labels, axis=0) + return np.squeeze(image), np.squeeze(labels) + + def _build_model_inputs(self, n_labels): + + # get label info + _, _, n_dims, _, _, _ = utils.get_volume_info(self.labels_paths[0]) + + # Generate! + while True: + + # randomly pick as many images as batchsize + indices = npr.randint(len(self.labels_paths), size=self.batchsize) + + # initialise input lists + list_label_maps = [] + list_means = [] + list_stds = [] + + for idx in indices: + + # load label in identity space, and add them to inputs + y = utils.load_volume( + self.labels_paths[idx], dtype="int", aff_ref=np.eye(4) + ) + list_label_maps.append(utils.add_axis(y, axis=[0, -1])) + + # add means and standard deviations to inputs + means = np.empty((1, n_labels, 0)) + stds = np.empty((1, n_labels, 0)) + for channel in range(self.n_channels): + + # retrieve channel specific stats if necessary + if isinstance(self.prior_means, np.ndarray): + if ( + self.prior_means.shape[0] > 2 + ) & self.use_specific_stats_for_channel: + if self.prior_means.shape[0] / 2 != self.n_channels: + raise ValueError( + "the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_means = self.prior_means[ + 2 * channel : 2 * channel + 2, : + ] + else: + tmp_prior_means = self.prior_means + else: + tmp_prior_means = self.prior_means + if isinstance(self.prior_stds, np.ndarray): + if ( + self.prior_stds.shape[0] > 2 + ) & self.use_specific_stats_for_channel: + if self.prior_stds.shape[0] / 2 != self.n_channels: + raise ValueError( + "the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_stds = self.prior_stds[ + 2 * channel : 2 * channel + 2, : + ] + else: + tmp_prior_stds = self.prior_stds + else: + tmp_prior_stds = self.prior_stds + + # draw means and std devs from priors + tmp_classes_means = utils.draw_value_from_distribution( + tmp_prior_means, + n_labels, + self.prior_distributions, + 125.0, + 100.0, + positive_only=True, + ) + tmp_classes_stds = utils.draw_value_from_distribution( + tmp_prior_stds, + n_labels, + self.prior_distributions, + 15.0, + 10.0, + positive_only=True, + ) + tmp_means = utils.add_axis( + tmp_classes_means[self.generation_classes], axis=[0, -1] + ) + tmp_stds = utils.add_axis( + tmp_classes_stds[self.generation_classes], axis=[0, -1] + ) + means = np.concatenate([means, tmp_means], axis=-1) + stds = np.concatenate([stds, tmp_stds], axis=-1) + list_means.append(means) + list_stds.append(stds) + + # build list of inputs of augmentation model + list_inputs = [list_label_maps, list_means, list_stds] + if ( + self.batchsize > 1 + ): # concatenate individual input types if batchsize > 1 + list_inputs = [np.concatenate(item, 0) for item in list_inputs] + else: + list_inputs = [item[0] for item in list_inputs] + + yield list_inputs diff --git a/nobrainer/ext/lab2im/lab2im_model.py b/nobrainer/ext/lab2im/lab2im_model.py new file mode 100644 index 00000000..96b0ee8b --- /dev/null +++ b/nobrainer/ext/lab2im/lab2im_model.py @@ -0,0 +1,218 @@ +""" +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +import keras.layers as KL +from keras.models import Model + +# python imports +import numpy as np + +# project imports +from nobrainer.ext.lab2im import layers, utils +from nobrainer.ext.lab2im.edit_tensors import ( + blurring_sigma_for_downsampling, + resample_tensor, +) + + +def lab2im_model( + labels_shape, + n_channels, + generation_labels, + output_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + blur_range=1.15, +): + """ + This function builds a keras/tensorflow model to generate images from provided label maps. + The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. + The model will take as inputs: + -a label map + -a vector containing the means of the Gaussian Mixture Model for each label, + -a vector containing the standard deviations of the Gaussian Mixture Model for each label, + -an array of size batch*(n_dims+1)*(n_dims+1) representing a linear transformation + The model returns: + -the generated image normalised between 0 and 1. + -the corresponding label map, with only the labels present in output_labels (the other are reset to zero). + :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array. + :param n_channels: number of channels to be synthetised. + :param generation_labels: list of all possible label values in the input label maps. + Can be a sequence or a 1d numpy array. + :param output_labels: list of the same length as generation_labels to indicate which values to use in the label maps + returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be converted to + output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] if you wish to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param atlas_res: resolution of the input label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param target_res: target resolution of the generated images and corresponding label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param output_shape: (optional) desired shape of the output images. + If the atlas and target resolutions are the same, the output will be cropped to output_shape, and if the two + resolutions are different, the output will be resized with trilinear interpolation to output_shape. + Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape + if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given + or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled + from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15. + """ + + # reformat resolutions + labels_shape = utils.reformat_to_list(labels_shape) + n_dims, _ = utils.get_dims(labels_shape) + atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0] + target_res = ( + atlas_res + if (target_res is None) + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) + + # get shapes + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) + + # define model inputs + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="stds_input" + ) + + # deform labels + labels = layers.RandomSpatialDeformation(inter_method="nearest")(labels_input) + + # cropping + if crop_shape != labels_shape: + labels._keras_shape = tuple(labels.get_shape().as_list()) + labels = layers.RandomCrop(crop_shape)(labels) + + # build synthetic image + labels._keras_shape = tuple(labels.get_shape().as_list()) + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) + + # apply bias field + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.BiasFieldCorruption(0.3, 0.025, same_bias_for_all_channels=False)( + image + ) + + # intensity augmentation + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=0.2)(image) + + # blur image + sigma = blurring_sigma_for_downsampling(atlas_res, target_res) + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.GaussianBlur(sigma=sigma, random_blur_range=blur_range)(image) + + # resample to target res + if crop_shape != output_shape: + image = resample_tensor(image, output_shape, interp_method="linear") + labels = resample_tensor(labels, output_shape, interp_method="nearest") + + # reset unwanted labels to zero + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) + + # build model (dummy layer enables to keep the labels when plugging this model to other models) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) + brain_model = Model( + inputs=[labels_input, means_input, stds_input], outputs=[image, labels] + ) + + return brain_model + + +def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n): + + n_dims = len(atlas_res) + + # get resampling factor + if atlas_res.tolist() != target_res.tolist(): + resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)] + else: + resample_factor = None + + # output shape specified, need to get cropping shape, and resample shape if necessary + if output_shape is not None: + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int") + + # make sure that output shape is smaller or equal to label shape + if resample_factor is not None: + output_shape = [ + min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) + for i in range(n_dims) + ] + else: + output_shape = [ + min(labels_shape[i], output_shape[i]) for i in range(n_dims) + ] + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] + if output_shape != tmp_shape: + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + output_shape, output_div_by_n, tmp_shape + ) + ) + output_shape = tmp_shape + + # get cropping and resample shape + if resample_factor is not None: + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] + else: + cropping_shape = output_shape + + # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n + else: + cropping_shape = labels_shape + if resample_factor is not None: + output_shape = [ + int(np.around(cropping_shape[i] * resample_factor[i], 0)) + for i in range(n_dims) + ] + else: + output_shape = cropping_shape + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + output_shape = [ + utils.find_closest_number_divisible_by_m( + s, output_div_by_n, answer_type="closer" + ) + for s in output_shape + ] + + return cropping_shape, output_shape diff --git a/nobrainer/ext/lab2im/layers.py b/nobrainer/ext/lab2im/layers.py new file mode 100644 index 00000000..c171e428 --- /dev/null +++ b/nobrainer/ext/lab2im/layers.py @@ -0,0 +1,2661 @@ +""" +This file regroups several custom keras layers used in the generation model: + - RandomSpatialDeformation, + - RandomCrop, + - RandomFlip, + - SampleConditionalGMM, + - SampleResolution, + - GaussianBlur, + - DynamicGaussianBlur, + - MimicAcquisition, + - BiasFieldCorruption, + - IntensityAugmentation, + - DiceLoss, + - WeightedL2Loss, + - ResetValuesToZero, + - ConvertLabels, + - PadAroundCentre, + - MaskEdges + - ImageGradients + - RandomDilationErosion + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +# python imports +import keras +import keras.backend as K +from keras.layers import Layer +import numpy as np +import tensorflow as tf + +# project imports +from nobrainer.ext.lab2im import edit_tensors as l2i_et +from nobrainer.ext.lab2im import utils + +# third-party imports +from nobrainer.ext.neuron import utils as nrn_utils +import nobrainer.ext.neuron.layers as nrn_layers + + +class RandomSpatialDeformation(Layer): + """This layer spatially deforms one or several tensors with a combination of affine and elastic transformations. + The input tensors are expected to have the same shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + The non-linear deformation is obtained by: + 1) a small-size SVF is sampled from a centred normal distribution of random standard deviation. + 2) it is resized with trilinear interpolation to half the shape of the input tensor + 3) it is integrated to obtain a diffeomorphic transformation + 4) finally, it is resized (again with trilinear interpolation) to full image size + :param scaling_bounds: (optional) range of the random scaling to apply. The scaling factor for each dimension is + sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.15 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param enable_90_rotations: (optional) whether to rotate the input by a random angle chosen in {0, 90, 180, 270}. + This is done regardless of the value of rotation_bounds. If true, a different value is sampled for each dimension. + :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we + sample the small-size SVF. Set to 0 if you wish to completely turn the elastic deformation off. + :param nonlin_scale: (optional) if nonlin_std is not False, factor between the shapes of the input tensor + and the shape of the input non-linear tensor. + :param inter_method: (optional) interpolation method when deforming the input tensor. Can be 'linear', or 'nearest' + :param prob_deform: (optional) probability to apply spatial deformation + """ + + def __init__( + self, + scaling_bounds=0.15, + rotation_bounds=10, + shearing_bounds=0.02, + translation_bounds=False, + enable_90_rotations=False, + nonlin_std=4.0, + nonlin_scale=0.0625, + inter_method="linear", + prob_deform=1, + **kwargs + ): + + # shape attributes + self.n_inputs = 1 + self.inshape = None + self.n_dims = None + self.small_shape = None + + # deformation attributes + self.scaling_bounds = scaling_bounds + self.rotation_bounds = rotation_bounds + self.shearing_bounds = shearing_bounds + self.translation_bounds = translation_bounds + self.enable_90_rotations = enable_90_rotations + self.nonlin_std = nonlin_std + self.nonlin_scale = nonlin_scale + + # boolean attributes + self.apply_affine_trans = ( + (self.scaling_bounds is not False) + | (self.rotation_bounds is not False) + | (self.shearing_bounds is not False) + | (self.translation_bounds is not False) + | self.enable_90_rotations + ) + self.apply_elastic_trans = self.nonlin_std > 0 + self.prob_deform = prob_deform + + # interpolation methods + self.inter_method = inter_method + + super(RandomSpatialDeformation, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["scaling_bounds"] = self.scaling_bounds + config["rotation_bounds"] = self.rotation_bounds + config["shearing_bounds"] = self.shearing_bounds + config["translation_bounds"] = self.translation_bounds + config["enable_90_rotations"] = self.enable_90_rotations + config["nonlin_std"] = self.nonlin_std + config["nonlin_scale"] = self.nonlin_scale + config["inter_method"] = self.inter_method + config["prob_deform"] = self.prob_deform + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + inputshape = [input_shape] + else: + self.n_inputs = len(input_shape) + inputshape = input_shape + self.inshape = inputshape[0][1:] + self.n_dims = len(self.inshape) - 1 + + if self.apply_elastic_trans: + self.small_shape = utils.get_resample_shape( + self.inshape[: self.n_dims], self.nonlin_scale, self.n_dims + ) + else: + self.small_shape = None + + self.inter_method = utils.reformat_to_list( + self.inter_method, length=self.n_inputs, dtype="str" + ) + + self.built = True + super(RandomSpatialDeformation, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # reformat inputs and get its shape + if self.n_inputs < 2: + inputs = [inputs] + types = [v.dtype for v in inputs] + inputs = [tf.cast(v, dtype="float32") for v in inputs] + batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] + + # initialise list of transforms to operate + list_trans = list() + + # add affine deformation to inputs list + if self.apply_affine_trans: + affine_trans = utils.sample_affine_transform( + batchsize, + self.n_dims, + self.rotation_bounds, + self.scaling_bounds, + self.shearing_bounds, + self.translation_bounds, + self.enable_90_rotations, + ) + list_trans.append(affine_trans) + + # prepare non-linear deformation field and add it to inputs list + if self.apply_elastic_trans: + + # sample small field from normal distribution of specified std dev + trans_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.small_shape, dtype="int32")], + axis=0, + ) + trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std) + elastic_trans = tf.random.normal(trans_shape, stddev=trans_std) + + # reshape this field to half size (for smoother SVF), integrate it, and reshape to full image size + resize_shape = [ + max(int(self.inshape[i] / 2), self.small_shape[i]) + for i in range(self.n_dims) + ] + elastic_trans = nrn_layers.Resize( + size=resize_shape, interp_method="linear" + )(elastic_trans) + elastic_trans = nrn_layers.VecInt()(elastic_trans) + elastic_trans = nrn_layers.Resize( + size=self.inshape[: self.n_dims], interp_method="linear" + )(elastic_trans) + list_trans.append(elastic_trans) + + # apply deformations and return tensors with correct dtype + if self.apply_affine_trans | self.apply_elastic_trans: + if self.prob_deform == 1: + inputs = [ + nrn_layers.SpatialTransformer(m)([v] + list_trans) + for (m, v) in zip(self.inter_method, inputs) + ] + else: + rand_trans = tf.squeeze( + K.less(tf.random.uniform([1], 0, 1), self.prob_deform) + ) + inputs = [ + K.switch( + rand_trans, + nrn_layers.SpatialTransformer(m)([v] + list_trans), + v, + ) + for (m, v) in zip(self.inter_method, inputs) + ] + if self.n_inputs < 2: + return tf.cast(inputs[0], types[0]) + else: + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + + +class RandomCrop(Layer): + """Randomly crop all input tensors to a given shape. This cropping is applied to all channels. + The input tensors are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param crop_shape: list with cropping shape in each dimension (excluding batch and channel dimension) + + example: + if input is a tensor of shape [batchsize, 160, 160, 160, 3], + output = RandomCrop(crop_shape=[96, 128, 96])(input) + will yield an output of shape [batchsize, 96, 128, 96, 3] that is obtained by cropping with randomly selected + cropping indices. + """ + + def __init__(self, crop_shape, **kwargs): + + self.several_inputs = True + self.crop_max_val = None + self.crop_shape = crop_shape + self.n_dims = len(crop_shape) + self.list_n_channels = None + super(RandomCrop, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["crop_shape"] = self.crop_shape + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + self.several_inputs = False + inputshape = [input_shape] + else: + inputshape = input_shape + self.crop_max_val = np.array( + np.array(inputshape[0][1 : self.n_dims + 1]) + ) - np.array(self.crop_shape) + self.list_n_channels = [i[-1] for i in inputshape] + self.built = True + super(RandomCrop, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # if one input only is provided, performs the cropping directly + if not self.several_inputs: + return tf.map_fn(self._single_slice, inputs, dtype=inputs.dtype) + + # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location + else: + types = [v.dtype for v in inputs] + inputs = tf.concat([tf.cast(v, "float32") for v in inputs], axis=-1) + inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32) + inputs = tf.split(inputs, self.list_n_channels, axis=-1) + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + + def _single_slice(self, vol): + crop_idx = tf.cast( + tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), "float32"), + dtype="int32", + ) + crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype="int32")], axis=0) + crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype="int32") + return tf.slice(vol, begin=crop_idx, size=crop_size) + + def compute_output_shape(self, input_shape): + output_shape = [ + tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels + ] + return output_shape if self.several_inputs else output_shape[0] + + +class RandomFlip(Layer): + """This layer randomly flips the input tensor along the specified axes with a specified probability. + It can also take multiple tensors as inputs (if they have the same shape). The same flips will be applied to all + input tensors. These are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + If specified, this layer can also swap corresponding values. This is especially useful when flipping label maps + with different labels for right/left structures, such that the flipped label maps keep a consistent labelling. + :param axis: integer, or list of integers specifying the dimensions along which to flip. + If a list, the input tensors can be flipped simultaneously in several directions. The values in flip_axis exclude + the batch dimension (e.g. 0 will flip the tensor along the first axis after the batch dimension). + Default is None, where the tensors can be flipped along all axes (except batch and channel axes). + :param swap_labels: boolean to specify whether to swap the values of each input. Values are only swapped if an odd + number of flips is applied. + Can also be a list if several tensors are given as input. + All the inputs for which the values need to be swapped must be int32 or int64. + :param label_list: if swap_labels is True, list of all labels contained in labels. Must be ordered as follows, first + the neutral labels (i.e. non-sided), then left labels and right labels. + :param n_neutral_labels: if swap_labels is True, number of non-sided labels + :param prob: probability to flip along each specified axis + + example 1: + if input is a tensor of shape (batchsize, 10, 100, 200, 3) + output = RandomFlip()(input) will randomly flip input along one of the 1st, 2nd, or 3rd axis (i.e. those with shape + 10, 100, 200). + + example 2: + if input is a tensor of shape (batchsize, 10, 100, 200, 3) + output = RandomFlip(flip_axis=1)(input) will randomly flip input along the 3rd axis (with shape 100), i.e. the axis + with index 1 if we don't count the batch axis. + + example 3: + input = tf.convert_to_tensor(np.array([[1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 0, 0, 0]])) + label_list = np.array([0, 1, 2]) + n_neutral_labels = 1 + output = RandomFlip(flip_axis=1, swap_labels=True, label_list=label_list, n_neutral_labels=n_neutral_labels)(input) + where output will either be equal to input (bear in mind the flipping occurs with a 0.5 probability), or: + output = [[0, 0, 0, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 0, 0, 0, 0, 0, 2]] + Note that the input must have a dtype int32 or int64 for its values to be swapped, otherwise an error will be raised + + example 4: + if labels is the same as in the input of example 3, and image is a float32 image, then we can swap consistently both + the labels and the image with: + labels, image = RandomFlip(flip_axis=1, swap_labels=[True, False], label_list=label_list, + n_neutral_labels=n_neutral_labels)([labels, image]]) + Note that the labels must have a dtype int32 or int64 to be swapped, otherwise an error will be raised. + This doesn't concern the image input, as its values are not swapped. + """ + + def __init__( + self, + axis=None, + swap_labels=False, + label_list=None, + n_neutral_labels=None, + prob=0.5, + **kwargs + ): + + # shape attributes + self.several_inputs = True + self.n_dims = None + self.list_n_channels = None + + # axis along which to flip + self.axis = utils.reformat_to_list(axis) + self.flip_axes = None + + # whether to swap labels, and corresponding label list + self.swap_labels = utils.reformat_to_list(swap_labels) + self.label_list = label_list + self.n_neutral_labels = n_neutral_labels + self.swap_lut = None + + self.prob = prob + + super(RandomFlip, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["axis"] = self.axis + config["swap_labels"] = self.swap_labels + config["label_list"] = self.label_list + config["n_neutral_labels"] = self.n_neutral_labels + config["prob"] = self.prob + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + self.several_inputs = False + inputshape = [input_shape] + else: + inputshape = input_shape + self.n_dims = len(inputshape[0][1:-1]) + self.list_n_channels = [i[-1] for i in inputshape] + self.swap_labels = utils.reformat_to_list( + self.swap_labels, length=len(inputshape) + ) + self.flip_axes = ( + np.arange(self.n_dims).tolist() if self.axis is None else self.axis + ) + + # create label list with swapped labels + if any(self.swap_labels): + assert (self.label_list is not None) & ( + self.n_neutral_labels is not None + ), "please provide a label_list, and n_neutral_labels when swapping the values of at least one input" + n_labels = len(self.label_list) + if self.n_neutral_labels == n_labels: + self.swap_labels = [False] * len(self.swap_labels) + else: + rl_split = np.split( + self.label_list, + [ + self.n_neutral_labels, + self.n_neutral_labels + + int((n_labels - self.n_neutral_labels) / 2), + ], + ) + label_list_swap = np.concatenate( + (rl_split[0], rl_split[2], rl_split[1]) + ) + swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap) + self.swap_lut = tf.convert_to_tensor(swap_lut, dtype="int32") + + self.built = True + super(RandomFlip, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # convert inputs to list, and get each input type + inputs = [inputs] if not self.several_inputs else inputs + types = [v.dtype for v in inputs] + + # store whether to flip along each specified dimension + batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] + size = tf.concat( + [batchsize, len(self.flip_axes) * tf.ones(1, dtype="int32")], axis=0 + ) + rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob) + + # swap right/left labels if we apply an odd number of flips + odd = ( + tf.math.floormod( + tf.reduce_sum(tf.cast(rand_flip, "int32"), -1, keepdims=True), 2 + ) + != 0 + ) + swapped_inputs = list() + for i in range(len(inputs)): + if self.swap_labels[i]: + swapped_inputs.append( + tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i]) + ) + else: + swapped_inputs.append(inputs[i]) + + # flip inputs and convert them back to their original type + inputs = tf.concat([tf.cast(v, "float32") for v in swapped_inputs], axis=-1) + inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32) + inputs = tf.split(inputs, self.list_n_channels, axis=-1) + + if self.several_inputs: + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + else: + return tf.cast(inputs[0], types[0]) + + def _single_swap(self, inputs): + return K.switch(inputs[1], tf.gather(self.swap_lut, inputs[0]), inputs[0]) + + @staticmethod + def _single_flip(inputs): + flip_axis = tf.where(inputs[1]) + return K.switch( + tf.equal(tf.size(flip_axis), 0), + inputs[0], + tf.reverse(inputs[0], axis=flip_axis[..., 0]), + ) + + +class SampleConditionalGMM(Layer): + """This layer generates an image by sampling a Gaussian Mixture Model conditioned on a label map given as input. + The parameters of the GMM are given as two additional inputs to the layer (means and standard deviations): + image = SampleConditionalGMM(generation_labels)([label_map, means, stds]) + + :param generation_labels: list of all possible label values contained in the input label maps. + Must be a list or a 1D numpy array of size N, where N is the total number of possible label values. + + Layer inputs: + label_map: input label map of shape [batchsize, shape_dim1, ..., shape_dimn, n_channel]. + All the values of label_map must be contained in generation_labels, but the input label_map doesn't necessarily have + to contain all the values in generation_labels. + means: tensor containing the mean values of all Gaussian distributions of the GMM. + It must be of shape [batchsize, N, n_channel], and in the same order as generation label, + i.e. the ith value of generation_labels will be associated to the ith value of means. + stds: same as means but for the standard deviations of the GMM. + """ + + def __init__(self, generation_labels, **kwargs): + self.generation_labels = generation_labels + self.n_labels = None + self.n_channels = None + self.max_label = None + self.indices = None + self.shape = None + super(SampleConditionalGMM, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["generation_labels"] = self.generation_labels + return config + + def build(self, input_shape): + + # check n_labels and n_channels + assert ( + len(input_shape) == 3 + ), "should have three inputs: labels, means, std devs (in that order)." + self.n_channels = input_shape[1][-1] + self.n_labels = len(self.generation_labels) + assert ( + self.n_labels == input_shape[1][1] + ), "means should have the same number of values as generation_labels" + assert ( + self.n_labels == input_shape[2][1] + ), "stds should have the same number of values as generation_labels" + + # scatter parameters (to build mean/std lut) + self.max_label = np.max(self.generation_labels) + 1 + indices = np.concatenate( + [ + self.generation_labels + self.max_label * i + for i in range(self.n_channels) + ], + axis=-1, + ) + self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype="int32") + self.indices = tf.convert_to_tensor( + utils.add_axis(indices, axis=[0, -1]), dtype="int32" + ) + + self.built = True + super(SampleConditionalGMM, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # reformat labels and scatter indices + batch = tf.split(tf.shape(inputs[0]), [1, -1])[0] + tmp_indices = tf.tile( + self.indices, + tf.concat([batch, tf.convert_to_tensor([1, 1], dtype="int32")], axis=0), + ) + labels = tf.concat( + [ + tf.cast(inputs[0], dtype="int32") + self.max_label * i + for i in range(self.n_channels) + ], + -1, + ) + + # build mean map + means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1) + tile_shape = tf.concat( + [ + batch, + tf.convert_to_tensor( + [ + 1, + ], + dtype="int32", + ), + ], + axis=0, + ) + means = tf.tile( + tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape + ) + means_map = tf.map_fn( + lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32 + ) + + # same for stds + stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1) + stds = tf.tile( + tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape + ) + stds_map = tf.map_fn( + lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32 + ) + + return stds_map * tf.random.normal(tf.shape(labels)) + means_map + + def compute_output_shape(self, input_shape): + return ( + input_shape[0] + if (self.n_channels == 1) + else tuple(list(input_shape[0][:-1]) + [self.n_channels]) + ) + + +class SampleResolution(Layer): + """Build a random resolution tensor by sampling a uniform distribution of provided range. + + You can use this layer in the following ways: + resolution = SampleConditionalGMM(min_resolution)() in this case resolution will be a tensor of shape (n_dims,), + where n_dims is the length of the min_resolution parameter (provided as a list, see below). + resolution = SampleConditionalGMM(min_resolution)(input), where input is a tensor for which the first dimension + represents the batch_size. In this case resolution will be a tensor of shape (batchsize, n_dims,). + + :param min_resolution: list of length n_dims specifying the inferior bounds of the uniform distributions to + sample from for each value. + :param max_res_iso: If not None, all the values of resolution will be equal to the same value, which is randomly + sampled at each minibatch in U(min_resolution, max_res_iso). + :param max_res_aniso: If not None, we first randomly select a direction i in the range [0, n_dims-1], and we sample + a value in the corresponding uniform distribution U(min_resolution[i], max_res_aniso[i]). + The other values of resolution will be set to min_resolution. + :param prob_iso: if both max_res_iso and max_res_aniso are specified, this allows to specify the probability of + sampling an isotropic resolution (therefore using max_res_iso) with respect to anisotropic resolution + (which would use max_res_aniso). + :param prob_min: if not zero, this allows to return with the specified probability an output resolution equal + to min_resolution. + :param return_thickness: if set to True, this layer will also return a thickness value of the same shape as + resolution, which will be sampled independently for each axis from the uniform distribution + U(min_resolution, resolution). + + """ + + def __init__( + self, + min_resolution, + max_res_iso=None, + max_res_aniso=None, + prob_iso=0.1, + prob_min=0.05, + return_thickness=True, + **kwargs + ): + + self.min_res = min_resolution + self.max_res_iso_input = max_res_iso + self.max_res_iso = None + self.max_res_aniso_input = max_res_aniso + self.max_res_aniso = None + self.prob_iso = prob_iso + self.prob_min = prob_min + self.return_thickness = return_thickness + self.n_dims = len(self.min_res) + self.add_batchsize = False + self.min_res_tens = None + super(SampleResolution, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["min_resolution"] = self.min_res + config["max_res_iso"] = self.max_res_iso + config["max_res_aniso"] = self.max_res_aniso + config["prob_iso"] = self.prob_iso + config["prob_min"] = self.prob_min + config["return_thickness"] = self.return_thickness + return config + + def build(self, input_shape): + + # check maximum resolutions + assert (self.max_res_iso_input is not None) | ( + self.max_res_aniso_input is not None + ), "at least one of maximum isotropic or anisotropic resolutions must be provided, received none" + + # reformat resolutions as numpy arrays + self.min_res = np.array(self.min_res) + if self.max_res_iso_input is not None: + self.max_res_iso = np.array(self.max_res_iso_input) + assert len(self.min_res) == len(self.max_res_iso), ( + "min and isotropic max resolution must have the same length, " + "had {0} and {1}".format(self.min_res, self.max_res_iso) + ) + if np.array_equal(self.min_res, self.max_res_iso): + self.max_res_iso = None + if self.max_res_aniso_input is not None: + self.max_res_aniso = np.array(self.max_res_aniso_input) + assert len(self.min_res) == len(self.max_res_aniso), ( + "min and anisotropic max resolution must have the same length, " + "had {} and {}".format(self.min_res, self.max_res_aniso) + ) + if np.array_equal(self.min_res, self.max_res_aniso): + self.max_res_aniso = None + + # check prob iso + if ( + (self.max_res_iso is not None) + & (self.max_res_aniso is not None) + & (self.prob_iso == 0) + ): + raise Exception( + "prob iso is 0 while sampling either isotropic and anisotropic resolutions is enabled" + ) + + if input_shape: + self.add_batchsize = True + + self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype="float32") + + self.built = True + super(SampleResolution, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if not self.add_batchsize: + shape = [self.n_dims] + dim = tf.random.uniform( + shape=(1, 1), minval=0, maxval=self.n_dims, dtype="int32" + ) + mask = tf.tensor_scatter_nd_update( + tf.zeros([self.n_dims], dtype="bool"), + dim, + tf.convert_to_tensor([True], dtype="bool"), + ) + else: + batch = tf.split(tf.shape(inputs), [1, -1])[0] + tile_shape = tf.concat( + [batch, tf.convert_to_tensor([1], dtype="int32")], axis=0 + ) + self.min_res_tens = tf.tile( + tf.expand_dims(self.min_res_tens, 0), tile_shape + ) + + shape = tf.concat( + [batch, tf.convert_to_tensor([self.n_dims], dtype="int32")], axis=0 + ) + indices = tf.stack( + [ + tf.range(0, batch[0]), + tf.random.uniform(batch, 0, self.n_dims, dtype="int32"), + ], + 1, + ) + mask = tf.tensor_scatter_nd_update( + tf.zeros(shape, dtype="bool"), indices, tf.ones(batch, dtype="bool") + ) + + # return min resolution as tensor if min=max + if (self.max_res_iso is None) & (self.max_res_aniso is None): + new_resolution = self.min_res_tens + + # sample isotropic resolution only + elif (self.max_res_iso is not None) & (self.max_res_aniso is None): + new_resolution_iso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_iso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution_iso, + ) + + # sample anisotropic resolution only + elif (self.max_res_iso is None) & (self.max_res_aniso is not None): + new_resolution_aniso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_aniso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + tf.where(mask, new_resolution_aniso, self.min_res_tens), + ) + + # sample either anisotropic or isotropic resolution + else: + new_resolution_iso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_iso + ) + new_resolution_aniso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_aniso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), + new_resolution_iso, + tf.where(mask, new_resolution_aniso, self.min_res_tens), + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution, + ) + + if self.return_thickness: + return [ + new_resolution, + tf.random.uniform( + tf.shape(self.min_res_tens), self.min_res_tens, new_resolution + ), + ] + else: + return new_resolution + + def compute_output_shape(self, input_shape): + if self.return_thickness: + return ( + [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2 + ) + else: + return (None, self.n_dims) if self.add_batchsize else self.n_dims + + +class GaussianBlur(Layer): + """Applies gaussian blur to an input image. + The input image is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param sigma: standard deviation of the blurring kernels to apply. Can be a number, a list of length n_dims, or + a numpy array. + :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where + sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds + [1/random_blur_range, random_blur_range]. + :param use_mask: (optional) whether a mask of the input will be provided as an additional layer input. This is used + to mask the blurred image, and to correct for edge blurring effects. + + example 1: + output = GaussianBlur(sigma=0.5)(input) will isotropically blur the input with a gaussian kernel of std 0.5. + + example 2: + if input is a tensor of shape [batchsize, 10, 100, 200, 2] + output = GaussianBlur(sigma=[0.5, 1, 10])(input) will blur the input a different gaussian kernel in each dimension. + + example 3: + output = GaussianBlur(sigma=0.5, random_blur_range=1.15)(input) + will blur the input a different gaussian kernel in each dimension, as each dimension will be associated with + a kernel, whose standard deviation will be uniformly sampled from [0.5/1.15; 0.5*1.15]. + + example 4: + output = GaussianBlur(sigma=0.5, use_mask=True)([input, mask]) + will 1) blur the input a different gaussian kernel in each dimension, 2) mask the blurred image with the provided + mask, and 3) correct for edge blurring effects. If the provided mask is not of boolean type, it will be thresholded + above positive values. + """ + + def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs): + self.sigma = utils.reformat_to_list(sigma) + assert np.all( + np.array(self.sigma) >= 0 + ), "sigma should be superior or equal to 0" + self.use_mask = use_mask + + self.n_dims = None + self.n_channels = None + self.blur_range = random_blur_range + self.stride = None + self.separable = None + self.kernels = None + self.convnd = None + super(GaussianBlur, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["sigma"] = self.sigma + config["random_blur_range"] = self.blur_range + config["use_mask"] = self.use_mask + return config + + def build(self, input_shape): + + # get shapes + if self.use_mask: + assert ( + len(input_shape) == 2 + ), "please provide a mask as second layer input when use_mask=True" + self.n_dims = len(input_shape[0]) - 2 + self.n_channels = input_shape[0][-1] + else: + self.n_dims = len(input_shape) - 2 + self.n_channels = input_shape[-1] + + # prepare blurring kernel + self.stride = [1] * (self.n_dims + 2) + self.sigma = utils.reformat_to_list(self.sigma, length=self.n_dims) + self.separable = np.linalg.norm(np.array(self.sigma)) > 5 + if self.blur_range is None: # fixed kernels + self.kernels = l2i_et.gaussian_kernel(self.sigma, separable=self.separable) + else: + self.kernels = None + + # prepare convolution + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) + + self.built = True + super(GaussianBlur, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if self.use_mask: + image = inputs[0] + mask = tf.cast(inputs[1], "bool") + else: + image = inputs + mask = None + + # redefine the kernels at each new step when blur_range is activated + if self.blur_range is not None: + self.kernels = l2i_et.gaussian_kernel( + self.sigma, blur_range=self.blur_range, separable=self.separable + ) + + if self.separable: + for k in self.kernels: + if k is not None: + image = tf.concat( + [ + self.convnd( + tf.expand_dims(image[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) + if self.use_mask: + maskb = tf.cast(mask, "float32") + maskb = tf.concat( + [ + self.convnd( + tf.expand_dims(maskb[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) + image = image / (maskb + K.epsilon()) + image = tf.where(mask, image, tf.zeros_like(image)) + else: + if any(self.sigma): + image = tf.concat( + [ + self.convnd( + tf.expand_dims(image[..., n], -1), + self.kernels, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) + if self.use_mask: + maskb = tf.cast(mask, "float32") + maskb = tf.concat( + [ + self.convnd( + tf.expand_dims(maskb[..., n], -1), + self.kernels, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) + image = image / (maskb + K.epsilon()) + image = tf.where(mask, image, tf.zeros_like(image)) + + return image + + +class DynamicGaussianBlur(Layer): + """Applies gaussian blur to an input image, where the standard deviation of the blurring kernel is provided as a + layer input, which enables to perform dynamic blurring (i.e. the blurring kernel can vary at each minibatch). + :param max_sigma: maximum value of the standard deviation that will be provided as input. This is used to compute + the size of the blurring kernels. It must be provided as a list of length n_dims. + :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where + sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds + [1/random_blur_range, random_blur_range]. + + example: + blurred_image = DynamicGaussianBlur(max_sigma=[5.]*3, random_blurring_range=1.15)([image, sigma]) + will return a blurred version of image, where the standard deviation of each dimension (given as a tensor, and with + values lower than 5 for each axis) is multiplied by a random coefficient uniformly sampled from [1/1.15; 1.15]. + """ + + def __init__(self, max_sigma, random_blur_range=None, **kwargs): + self.max_sigma = max_sigma + self.n_dims = None + self.n_channels = None + self.convnd = None + self.blur_range = random_blur_range + self.separable = None + super(DynamicGaussianBlur, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["max_sigma"] = self.max_sigma + config["random_blur_range"] = self.blur_range + return config + + def build(self, input_shape): + assert ( + len(input_shape) == 2 + ), "sigma should be provided as an input tensor for dynamic blurring" + self.n_dims = len(input_shape[0]) - 2 + self.n_channels = input_shape[0][-1] + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) + self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims) + self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5 + self.built = True + super(DynamicGaussianBlur, self).build(input_shape) + + def call(self, inputs, **kwargs): + image = inputs[0] + sigma = inputs[-1] + kernels = l2i_et.gaussian_kernel( + sigma, self.max_sigma, self.blur_range, self.separable + ) + if self.separable: + for kernel in kernels: + image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32) + else: + image = tf.map_fn(self._single_blur, [image, kernels], dtype=tf.float32) + return image + + def _single_blur(self, inputs): + if self.n_channels > 1: + split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1) + blurred_channel = list() + for channel in split_channels: + blurred = self.convnd( + tf.expand_dims(channel, 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ) + blurred_channel.append(tf.squeeze(blurred, axis=0)) + output = tf.concat(blurred_channel, -1) + else: + output = self.convnd( + tf.expand_dims(inputs[0], 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ) + output = tf.squeeze(output, axis=0) + return output + + +class MimicAcquisition(Layer): + """ + Layer that takes an image as input, and simulates data that has been acquired at low resolution. + The output is obtained by resampling the input twice: + - first at a resolution given as an input (i.e. the "acquisition" resolution), + - then at the output resolution (specified output shape). + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param volume_res: resolution of the provided inputs. Must be a 1-D numpy array with n_dims elements. + :param min_subsample_res: lower bound of the acquisition resolutions to mimic (i.e. the input resolution must have + values higher than min-subsample_res). + :param resample_shape: shape of the output tensor + :param build_dist_map: whether to return distance maps as outputs. These indicate the distance between each voxel + and the nearest non-interpolated voxel (during the second resampling). + :param prob_noise: probability to apply noise injection + + example 1: + im_res = [1., 1., 1.] + low_res = [1., 1., 3.] + res = tf.convert_to_tensor([1., 1., 4.5]) + image is a tensor of shape (None, 256, 256, 256, 3) + resample_shape = [256, 256, 256] + output = MimicAcquisition(im_res, low_res, resample_shape)([image, res]) + output will be a tensor of shape (None, 256, 256, 256, 3), obtained by downsampling image to [1., 1., 4.5]. + and re-upsampling it at initial resolution (because resample_shape is equal to the input shape). In this example all + examples of the batch will be downsampled to the same resolution (because res has no batch dimension). + Note that the provided res must have higher values than min_low_res. + + example 2: + im_res = [1., 1., 1.] + min_low_res = [1., 1., 1.] + res is a tensor of shape (None, 3), obtained for example by using the SampleResolution layer (see above). + image is a tensor of shape (None, 256, 256, 256, 1) + resample_shape = [128, 128, 128] + output = MimicAcquisition(im_res, low_res, resample_shape)([image, res]) + output will be a tensor of shape (None, 128, 128, 128, 1), obtained by downsampling each examples of the batch to + the matching resolution in res, and resampling them all to half the initial resolution. + Note that the provided res must have higher values than min_low_res. + """ + + def __init__( + self, + volume_res, + min_subsample_res, + resample_shape, + build_dist_map=False, + noise_std=0, + prob_noise=0.95, + **kwargs + ): + + # resolutions and dimensions + self.volume_res = volume_res + self.min_subsample_res = min_subsample_res + self.n_dims = len(self.volume_res) + self.n_channels = None + self.add_batchsize = None + + # noise + self.noise_std = noise_std + self.prob_noise = prob_noise + + # input and output shapes + self.inshape = None + self.resample_shape = resample_shape + + # meshgrids for resampling + self.down_grid = None + self.up_grid = None + + # whether to return a map indicating the distance from the interpolated voxels, to acquired ones. + self.build_dist_map = build_dist_map + + super(MimicAcquisition, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["volume_res"] = self.volume_res + config["min_subsample_res"] = self.min_subsample_res + config["resample_shape"] = self.resample_shape + config["build_dist_map"] = self.build_dist_map + config["noise_std"] = self.noise_std + config["prob_noise"] = self.prob_noise + return config + + def build(self, input_shape): + + # set up input shape and acquisition shape + self.inshape = input_shape[0][1:] + self.n_channels = input_shape[0][-1] + self.add_batchsize = False if (input_shape[1][0] is None) else True + down_tensor_shape = np.int32( + np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res + ) + + # build interpolation meshgrids + self.down_grid = tf.expand_dims( + tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0 + ) + self.up_grid = tf.expand_dims( + tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0 + ) + + self.built = True + super(MimicAcquisition, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # sort inputs + assert ( + len(inputs) == 2 + ), "inputs must have two items, the tensor to resample, and the downsampling resolution" + vol = inputs[0] + subsample_res = tf.cast(inputs[1], dtype="float32") + vol = K.reshape(vol, [-1, *self.inshape]) # necessary for multi_gpu models + batchsize = tf.split(tf.shape(vol), [1, -1])[0] + tile_shape = tf.concat([batchsize, tf.ones([1], dtype="int32")], 0) + + # get downsampling and upsampling factors + if self.add_batchsize: + subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape) + down_shape = tf.cast( + tf.convert_to_tensor( + np.array(self.inshape[:-1]) * self.volume_res, dtype="float32" + ) + / subsample_res, + dtype="int32", + ) + down_zoom_factor = tf.cast( + down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype="float32" + ) + up_zoom_factor = tf.cast( + tf.convert_to_tensor(self.resample_shape, dtype="int32") / down_shape, + dtype="float32", + ) + + # downsample + down_loc = tf.tile( + self.down_grid, + tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype="int32")], 0), + ) + down_loc = tf.cast(down_loc, "float32") / l2i_et.expand_dims( + down_zoom_factor, axis=[1] * self.n_dims + ) + inshape_tens = tf.tile( + tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape + ) + inshape_tens = l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims) + down_loc = K.clip(down_loc, 0.0, tf.cast(inshape_tens, "float32")) + vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32) + + # add noise with predefined probability + if self.noise_std > 0: + sample_shape = tf.concat( + [ + batchsize, + tf.ones([self.n_dims], dtype="int32"), + self.n_channels * tf.ones([1], dtype="int32"), + ], + 0, + ) + noise = tf.random.normal( + tf.shape(vol), + stddev=tf.random.uniform(sample_shape, maxval=self.noise_std), + ) + if self.prob_noise == 1: + vol += noise + else: + vol = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + vol + noise, + vol, + ) + + # upsample + up_loc = tf.tile( + self.up_grid, + tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype="int32")], axis=0), + ) + up_loc = tf.cast(up_loc, "float32") / l2i_et.expand_dims( + up_zoom_factor, axis=[1] * self.n_dims + ) + vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32) + + # return upsampled volume + if not self.build_dist_map: + return vol + + # return upsampled volumes with distance maps + else: + + # get grid points + floor = tf.math.floor(up_loc) + ceil = tf.math.ceil(up_loc) + + # get distances of every voxel to higher and lower grid points for every dimension + f_dist = up_loc - floor + c_dist = ceil - up_loc + + # keep minimum 1d distances, and compute 3d distance to nearest grid point + dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims( + subsample_res, axis=[1] * self.n_dims + ) + dist = tf.math.sqrt( + tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True) + ) + + return [vol, dist] + + @staticmethod + def _single_down_interpn(inputs): + return nrn_utils.interpn(inputs[0], inputs[1], interp_method="nearest") + + @staticmethod + def _single_up_interpn(inputs): + return nrn_utils.interpn(inputs[0], inputs[1], interp_method="linear") + + def compute_output_shape(self, input_shape): + output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]]) + return [output_shape] * 2 if self.build_dist_map else output_shape + + +class BiasFieldCorruption(Layer): + """This layer applies a smooth random bias field to the input by applying the following steps: + 1) we first sample a value for the standard deviation of a centred normal distribution + 2) a small-size SVF is sampled from this normal distribution + 3) the small SVF is then resized with trilinear interpolation to image size + 4) it is rescaled to positive values by taking the voxel-wise exponential + 5) it is multiplied to the input tensor. + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param bias_field_std: maximum value of the standard deviation sampled in 1 (it will be sampled from the range + [0, bias_field_std]) + :param bias_scale: ratio between the shape of the input tensor and the shape of the sampled SVF. + :param same_bias_for_all_channels: whether to apply the same bias field to all the channels of the input tensor. + :param prob: probability to apply this bias field corruption. + """ + + def __init__( + self, + bias_field_std=0.5, + bias_scale=0.025, + same_bias_for_all_channels=False, + prob=0.95, + **kwargs + ): + + # input shape + self.several_inputs = False + self.inshape = None + self.n_dims = None + self.n_channels = None + + # sampling shape + self.std_shape = None + self.small_bias_shape = None + + # bias field parameters + self.bias_field_std = bias_field_std + self.bias_scale = bias_scale + self.same_bias_for_all_channels = same_bias_for_all_channels + self.prob = prob + + super(BiasFieldCorruption, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["bias_field_std"] = self.bias_field_std + config["bias_scale"] = self.bias_scale + config["same_bias_for_all_channels"] = self.same_bias_for_all_channels + config["prob"] = self.prob + return config + + def build(self, input_shape): + + # input shape + if isinstance(input_shape, list): + self.several_inputs = True + self.inshape = input_shape + else: + self.inshape = [input_shape] + self.n_dims = len(self.inshape[0]) - 2 + self.n_channels = self.inshape[0][-1] + + # sampling shapes + self.std_shape = [1] * (self.n_dims + 1) + self.small_bias_shape = utils.get_resample_shape( + self.inshape[0][1 : self.n_dims + 1], self.bias_scale, 1 + ) + if not self.same_bias_for_all_channels: + self.std_shape[-1] = self.n_channels + self.small_bias_shape[-1] = self.n_channels + + self.built = True + super(BiasFieldCorruption, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if not self.several_inputs: + inputs = [inputs] + + if self.bias_field_std > 0: + + # sampling shapes + batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0] + std_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.std_shape, dtype="int32")], 0 + ) + bias_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype="int32")], + axis=0, + ) + + # sample small bias field + bias_field = tf.random.normal( + bias_shape, + stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std), + ) + + # resize bias field and take exponential + bias_field = nrn_layers.Resize( + size=self.inshape[0][1 : self.n_dims + 1], interp_method="linear" + )(bias_field) + bias_field = tf.math.exp(bias_field) + + # apply bias field with predefined probability + if self.prob == 1: + return [tf.math.multiply(bias_field, v) for v in inputs] + else: + rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob)) + if self.several_inputs: + return [ + K.switch(rand_trans, tf.math.multiply(bias_field, v), v) + for v in inputs + ] + else: + return K.switch( + rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0] + ) + + else: + return inputs + + +class IntensityAugmentation(Layer): + """This layer enables to augment the intensities of the input tensor, as well as to apply min_max normalisation. + The following steps are applied (all are optional): + 1) white noise corruption, with a randomly sampled std dev. + 2) clip the input between two values + 3) min-max normalisation + 4) gamma augmentation (i.e. voxel-wise exponentiation by a randomly sampled power) + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param noise_std: maximum value of the standard deviation of the Gaussian white noise used in 1 (it will be sampled + from the range [0, noise_std]). Set to 0 to skip this step. + :param clip: clip the input tensor between the given values. Can either be: a number (in which case we clip between + 0 and the given value), or a list or a numpy array with two elements. Default is 0, where no clipping occurs. + :param normalise: whether to apply min-max normalisation, to normalise between 0 and 1. Default is True. + :param norm_perc: percentiles (between 0 and 1) of the sorted intensity values for robust normalisation. Can be: + a number (in which case the robust minimum is the provided percentile of sorted values, and the maximum is the + 1 - norm_perc percentile), or a list/numpy array of 2 elements (percentiles for the minimum and maximum values). + The minimum and maximum values are computed separately for each channel if separate_channels is True. + Default is 0, where we simply take the minimum and maximum values. + :param gamma_std: standard deviation of the normal distribution from which we sample gamma (in log domain). + Default is 0, where no gamma augmentation occurs. + :param contrast_inversion: whether to perform contrast inversion (i.e. 1 - x). If True, this is performed randomly + for each element of the batch, as well as for each channel. + :param separate_channels: whether to augment all channels separately. Default is True. + :param prob_noise: probability to apply noise injection + :param prob_gamma: probability to apply gamma augmentation + """ + + def __init__( + self, + noise_std=0, + clip=0, + normalise=True, + norm_perc=0, + gamma_std=0, + contrast_inversion=False, + separate_channels=True, + prob_noise=0.95, + prob_gamma=1, + **kwargs + ): + + # shape attributes + self.n_dims = None + self.n_channels = None + self.flatten_shape = None + self.expand_minmax_dim = None + self.one = None + + # inputs + self.noise_std = noise_std + self.clip = clip + self.clip_values = None + self.normalise = normalise + self.norm_perc = norm_perc + self.perc = None + self.gamma_std = gamma_std + self.separate_channels = separate_channels + self.contrast_inversion = contrast_inversion + self.prob_noise = prob_noise + self.prob_gamma = prob_gamma + + super(IntensityAugmentation, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["noise_std"] = self.noise_std + config["clip"] = self.clip + config["normalise"] = self.normalise + config["norm_perc"] = self.norm_perc + config["gamma_std"] = self.gamma_std + config["separate_channels"] = self.separate_channels + config["prob_noise"] = self.prob_noise + config["prob_gamma"] = self.prob_gamma + return config + + def build(self, input_shape): + self.n_dims = len(input_shape) - 2 + self.n_channels = input_shape[-1] + self.flatten_shape = np.prod(np.array(input_shape[1:-1])) + self.flatten_shape = ( + self.flatten_shape * self.n_channels + if not self.separate_channels + else self.flatten_shape + ) + self.expand_minmax_dim = ( + self.n_dims if self.separate_channels else self.n_dims + 1 + ) + self.one = tf.ones([1], dtype="int32") + if self.clip: + self.clip_values = utils.reformat_to_list(self.clip) + self.clip_values = ( + self.clip_values + if len(self.clip_values) == 2 + else [0, self.clip_values[0]] + ) + else: + self.clip_values = None + if self.norm_perc: + self.perc = utils.reformat_to_list(self.norm_perc) + self.perc = ( + self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + ) + else: + self.perc = None + + self.built = True + super(IntensityAugmentation, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately) + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion: + sample_shape = tf.concat( + [batchsize, tf.ones([self.n_dims], dtype="int32")], 0 + ) + if self.separate_channels: + sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0) + else: + sample_shape = tf.concat([sample_shape, self.one], 0) + else: + sample_shape = None + + # add noise with predefined probability + if self.noise_std > 0: + noise_stddev = tf.random.uniform(sample_shape, maxval=self.noise_std) + if self.separate_channels: + noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev) + else: + noise = tf.random.normal( + tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev + ) + noise = tf.tile( + noise, + tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels]), + ) + if self.prob_noise == 1: + inputs = inputs + noise + else: + inputs = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + inputs + noise, + inputs, + ) + + # clip images to given values + if self.clip_values is not None: + inputs = K.clip(inputs, self.clip_values[0], self.clip_values[1]) + + # normalise + if self.normalise: + # define robust min and max by sorting values and taking percentile + if self.perc is not None: + if self.separate_channels: + shape = tf.concat( + [ + batchsize, + self.flatten_shape * self.one, + self.n_channels * self.one, + ], + 0, + ) + else: + shape = tf.concat([batchsize, self.flatten_shape * self.one], 0) + intensities = tf.sort(tf.reshape(inputs, shape), axis=1) + m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...] + M = intensities[ + :, + min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), + ..., + ] + # simple min and max + else: + m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) + M = K.max(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) + # normalise + m = l2i_et.expand_dims(m, axis=[1] * self.expand_minmax_dim) + M = l2i_et.expand_dims(M, axis=[1] * self.expand_minmax_dim) + inputs = tf.clip_by_value(inputs, m, M) + inputs = (inputs - m) / (M - m + K.epsilon()) + + # apply voxel-wise exponentiation with predefined probability + if self.gamma_std > 0: + gamma = tf.random.normal(sample_shape, stddev=self.gamma_std) + if self.prob_gamma == 1: + inputs = tf.math.pow(inputs, tf.math.exp(gamma)) + else: + inputs = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), + tf.math.pow(inputs, tf.math.exp(gamma)), + inputs, + ) + + # apply random contrast inversion + if self.contrast_inversion: + rand_invert = tf.less(tf.random.uniform(sample_shape, maxval=1), 0.5) + split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1) + split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1) + inverted_channel = list() + for channel, invert in zip(split_channels, split_rand_invert): + inverted_channel.append( + tf.map_fn( + self._single_invert, [channel, invert], dtype=channel.dtype + ) + ) + inputs = tf.concat(inverted_channel, -1) + + return inputs + + @staticmethod + def _single_invert(inputs): + return K.switch(tf.squeeze(inputs[1]), 1 - inputs[0], inputs[0]) + + +class DiceLoss(Layer): + """This layer computes the soft Dice loss between two tensors. + These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: dice_loss = DiceLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures + when computing the loss. Default is 0 where no boundary weighting is applied. + :param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels + within this distance to a region boundary. Default is 3. + :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be + redundant when we have several labels. This is only used if boundary_weight is not 0. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__( + self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs + ): + + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.boundary_weights = boundary_weights + self.boundary_dist = boundary_dist + self.skip_background = skip_background + self.enable_checks = enable_checks + self.spatial_axes = None + self.avg_pooling_layer = None + super(DiceLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["boundary_weights"] = self.boundary_weights + config["boundary_dist"] = self.boundary_dist + config["skip_background"] = self.skip_background + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert ( + len(input_shape) == 2 + ), "DiceLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + self.avg_pooling_layer = getattr(keras.layers, "AvgPool%dD" % n_dims) + self.skip_background = False if n_labels == 1 else self.skip_background + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) + + self.built = True + super(DiceLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] + pred = inputs[1] + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = K.clip( + gt + / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) + pred = K.clip( + pred + / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) + + # compute dice loss for each label + top = 2 * gt * pred + bottom = tf.math.square(gt) + tf.math.square(pred) + + # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) + if self.boundary_weights: + avg = self.avg_pooling_layer( + pool_size=2 * self.boundary_dist + 1, strides=1, padding="same" + )(gt) + boundaries = tf.cast(avg > 0.0, "float32") * tf.cast( + avg < (1 / len(self.spatial_axes) - 1e-4), "float32" + ) + if self.skip_background: + boundaries_channels = tf.unstack(boundaries, axis=-1) + boundaries = tf.stack( + [tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], + axis=-1, + ) + boundary_weights_tensor = 1 + self.boundary_weights * boundaries + top *= boundary_weights_tensor + bottom *= boundary_weights_tensor + else: + boundary_weights_tensor = None + + # compute loss + top = tf.math.reduce_sum(top, self.spatial_axes) + bottom = tf.math.reduce_sum(bottom, self.spatial_axes) + dice = (top + tf.keras.backend.epsilon()) / ( + bottom + tf.keras.backend.epsilon() + ) + loss = 1 - dice + + # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt + if ( + boundary_weights_tensor is not None + ): # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum( + gt * boundary_weights_tensor, self.spatial_axes + ) + else: + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + loss = tf.reduce_sum(loss * self.class_weights_tens, -1) + + return tf.math.reduce_mean(loss) + + def compute_output_shape(self, input_shape): + return [[]] + + +class WeightedL2Loss(Layer): + """This layer computes a L2 loss weighted by a specified factor (target_value) between two tensors. + This is designed to be used on the layer before the softmax. + The tensors are expected to have the same shape [batchsize, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: wl2_loss = WeightedL2Loss()([gt, pred]) + + :param target_value: target value for the layer before softmax: target_value when gt = 1, -target_value when gt = 0. + """ + + def __init__(self, target_value=5, **kwargs): + self.target_value = target_value + self.n_labels = None + super(WeightedL2Loss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["target_value"] = self.target_value + return config + + def build(self, input_shape): + assert ( + len(input_shape) == 2 + ), "DiceLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." + self.n_labels = input_shape[0][-1] + self.built = True + super(WeightedL2Loss, self).build(input_shape) + + def call(self, inputs, **kwargs): + gt = inputs[0] + pred = inputs[1] + weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1) + return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / ( + K.sum(weights) * self.n_labels + ) + + def compute_output_shape(self, input_shape): + return [[]] + + +class CrossEntropyLoss(Layer): + """This layer computes the cross-entropy loss between two tensors. + These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: ce_loss = CrossEntropyLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures + when computing the loss. Default is 0 where no boundary weighting is applied. + :param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels + within this distance to a region boundary. Default is 3. + :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be + redundant when we have several labels. This is only used if boundary_weight is not 0. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__( + self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs + ): + + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.boundary_weights = boundary_weights + self.boundary_dist = boundary_dist + self.skip_background = skip_background + self.enable_checks = enable_checks + self.spatial_axes = None + self.avg_pooling_layer = None + super(CrossEntropyLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["boundary_weights"] = self.boundary_weights + config["boundary_dist"] = self.boundary_dist + config["skip_background"] = self.skip_background + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert ( + len(input_shape) == 2 + ), "CrossEntropy expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + self.avg_pooling_layer = getattr(keras.layers, "AvgPool%dD" % n_dims) + self.skip_background = False if n_labels == 1 else self.skip_background + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") + self.class_weights_tens = l2i_et.expand_dims( + class_weights_tens, [0] * (1 + n_dims) + ) + + self.built = True + super(CrossEntropyLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] + pred = inputs[1] + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = K.clip( + gt + / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) + pred = pred / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ) + pred = K.clip( + pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon() + ) # to avoid log(0) + + # compare prediction/target, ce has the same shape has the input tensors + ce = -gt * tf.math.log(pred) + + # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) + if self.boundary_weights: + avg = self.avg_pooling_layer( + pool_size=2 * self.boundary_dist + 1, strides=1, padding="same" + )(gt) + boundaries = tf.cast(avg > 0.0, "float32") * tf.cast( + avg < (1 / len(self.spatial_axes) - 1e-4), "float32" + ) + if self.skip_background: + boundaries_channels = tf.unstack(boundaries, axis=-1) + boundaries = tf.stack( + [tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], + axis=-1, + ) + boundary_weights_tensor = 1 + self.boundary_weights * boundaries + ce *= boundary_weights_tensor + else: + boundary_weights_tensor = None + + # apply class weighting across labels. By the end of this, ce still has the same shape has the input tensors. + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt + if ( + boundary_weights_tensor is not None + ): # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum( + gt * boundary_weights_tensor, self.spatial_axes, True + ) + else: + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + ce = tf.reduce_sum(ce * self.class_weights_tens, -1) + + # sum along label axis, and take the mean along spatial dimensions + ce = tf.math.reduce_mean(tf.math.reduce_sum(ce, axis=-1)) + + return ce + + def compute_output_shape(self, input_shape): + return [[]] + + +class MomentLoss(Layer): + """This layer computes a moment loss between two tensors. Specifically, it computes the distance between the centres + of gravity for all the channels of the two tensors, and then returns a value averaged across all channels. + These tensors are expected to have the same shape [batch, size_dim1, ..., size_dimN, n_channels]. + The first input tensor is the GT and the second is the prediction: moment_loss = MomentLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, class_weights=None, enable_checks=False, **kwargs): + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.enable_checks = enable_checks + self.spatial_axes = None + self.coordinates = None + super(MomentLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert ( + len(input_shape) == 2 + ), "MomentLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + + # build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan) + self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1) + self.coordinates = tf.cast( + l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), + "float32", + ) + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) + + self.built = True + super(MomentLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] # (B, dim1, dim2, ..., dimN, nchan) + pred = inputs[1] + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = gt / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ) + pred = pred / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ) + + # compute loss + gt_mean_coordinates = self._mean_coordinates(gt) # (B, ndim, nchan) + pred_mean_coordinates = self._mean_coordinates(pred) + loss = tf.math.sqrt( + tf.reduce_sum( + tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1 + ) + ) # (B, nchan) + + # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + loss = tf.reduce_sum(loss * self.class_weights_tens, -1) + + return tf.math.reduce_mean(loss) + + def _mean_coordinates(self, tensor): + tensor = l2i_et.expand_dims( + tensor, axis=-2 + ) # (B, dim1, dim2, ..., dimN, 1, nchan) + numerator = tf.reduce_sum( + tensor * self.coordinates, axis=self.spatial_axes + ) # (B, ndim, nchan) + denominator = ( + tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + ) + return numerator / denominator + + def compute_output_shape(self, input_shape): + return [[]] + + +class ResetValuesToZero(Layer): + """This layer enables to reset given values to 0 within the input tensors. + + :param values: list of values to be reset to 0. + + example: + input = tf.convert_to_tensor(np.array([[1, 0, 2, 2, 2, 2, 0], + [1, 3, 3, 3, 3, 3, 3], + [1, 0, 0, 0, 4, 4, 4]])) + values = [1, 3] + ResetValuesToZero(values)(input) + >> [[0, 0, 2, 2, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 4, 4]] + """ + + def __init__(self, values, **kwargs): + assert ( + values is not None + ), "please provide correct list of values, received None" + self.values = utils.reformat_to_list(values) + self.values_tens = None + self.n_values = len(values) + super(ResetValuesToZero, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["values"] = self.values + return config + + def build(self, input_shape): + self.values_tens = tf.convert_to_tensor(self.values) + self.built = True + super(ResetValuesToZero, self).build(input_shape) + + def call(self, inputs, **kwargs): + values = tf.cast(self.values_tens, dtype=inputs.dtype) + for i in range(self.n_values): + inputs = tf.where( + tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs + ) + return inputs + + +class ConvertLabels(Layer): + """Convert all labels in a tensor by the corresponding given set of values. + labels_converted = ConvertLabels(source_values, dest_values)(labels). + labels must be an int32 tensor, and labels_converted will also be int32. + + :param source_values: list of all the possible values in labels. Must be a list or a 1D numpy array. + :param dest_values: list of all the target label values. Must be ordered the same as source values: + labels[labels == source_values[i]] = dest_values[i]. + If None (default), dest_values is equal to [0, ..., N-1], where N is the total number of values in source_values, + which enables to remap label maps to [0, ..., N-1]. + """ + + def __init__(self, source_values, dest_values=None, **kwargs): + self.source_values = source_values + self.dest_values = dest_values + self.lut = None + super(ConvertLabels, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["source_values"] = self.source_values + config["dest_values"] = self.dest_values + return config + + def build(self, input_shape): + self.lut = tf.convert_to_tensor( + utils.get_mapping_lut(self.source_values, dest=self.dest_values), + dtype="int32", + ) + self.built = True + super(ConvertLabels, self).build(input_shape) + + def call(self, inputs, **kwargs): + return tf.gather(self.lut, tf.cast(inputs, dtype="int32")) + + +class PadAroundCentre(Layer): + """Pad the input tensor to the specified shape with the given value. + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param pad_margin: margin to use for padding. The tensor will be padded by the provided margin on each side. + Can either be a number (all axes padded with the same margin), or a list/numpy array of length n_dims. + example: if tensor is of shape [batch, x, y, z, n_channels] and margin=10, then the padded tensor will be of + shape [batch, x+2*10, y+2*10, z+2*10, n_channels]. + :param pad_shape: shape to pad the tensor to. Can either be a number (all axes padded to the same shape), or a + list/numpy array of length n_dims. + :param value: value to pad the tensors with. Default is 0. + """ + + def __init__(self, pad_margin=None, pad_shape=None, value=0, **kwargs): + self.pad_margin = pad_margin + self.pad_shape = pad_shape + self.value = value + self.pad_margin_tens = None + self.pad_shape_tens = None + self.n_dims = None + super(PadAroundCentre, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["pad_margin"] = self.pad_margin + config["pad_shape"] = self.pad_shape + config["value"] = self.value + return config + + def build(self, input_shape): + # input shape + self.n_dims = len(input_shape) - 2 + shape = list(input_shape) + shape[0] = 0 + shape[-1] = 0 + + if self.pad_margin is not None: + assert ( + self.pad_shape is None + ), "please do not provide a padding shape and margin at the same time." + + # reformat padding margins + pad = np.transpose( + np.array( + [[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] + * 2 + ) + ) + self.pad_margin_tens = tf.convert_to_tensor(pad, dtype="int32") + + elif self.pad_shape is not None: + assert ( + self.pad_margin is None + ), "please do not provide a padding shape and margin at the same time." + + # pad shape + tensor_shape = tf.cast(tf.convert_to_tensor(shape), "int32") + self.pad_shape_tens = np.array( + [0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0] + ) + self.pad_shape_tens = tf.convert_to_tensor( + self.pad_shape_tens, dtype="int32" + ) + self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens) + + # padding margin + min_margins = (self.pad_shape_tens - tensor_shape) / 2 + max_margins = self.pad_shape_tens - tensor_shape - min_margins + self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1) + + else: + raise Exception( + "please either provide a padding shape or a padding margin." + ) + + self.built = True + super(PadAroundCentre, self).build(input_shape) + + def call(self, inputs, **kwargs): + return tf.pad( + inputs, self.pad_margin_tens, mode="CONSTANT", constant_values=self.value + ) + + +class MaskEdges(Layer): + """Reset the edges of a tensor to zero (i.e. with bands of zeros along the specified axes). + The width of the zero-band is randomly drawn from a uniform distribution, whose range is given in boundaries. + + :param axes: axes along which to reset edges to zero. Can be an int (single axis), or a sequence. + :param boundaries: numpy array of shape (len(axes), 4). Each row contains the two bounds of the uniform + distributions from which we draw the width of the zero-bands on each side. + Those bounds must be expressed in relative side (i.e. between 0 and 1). + :return: a tensor of the same shape as the input, with bands of zeros along the specified axes. + + example: + tensor=tf.constant([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]) # shape = [1,10,10,1] + axes=1 + boundaries = np.array([[0.2, 0.45, 0.85, 0.9]]) + + In this case, we reset the edges along the 2nd dimension (i.e. the 1st dimension after the batch dimension), + the 1st zero-band will expand from the 1st row to a number drawn from [0.2*tensor.shape[1], 0.45*tensor.shape[1]], + and the 2nd zero-band will expand from a row drawn from [0.85*tensor.shape[1], 0.9*tensor.shape[1]], to the end of + the tensor. A possible output could be: + array([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]) # shape = [1,10,10,1] + """ + + def __init__(self, axes, boundaries, prob_mask=1, **kwargs): + self.axes = utils.reformat_to_list(axes, dtype="int") + self.boundaries = utils.reformat_to_n_channels_array( + boundaries, n_dims=4, n_channels=len(self.axes) + ) + self.prob_mask = prob_mask + self.inputshape = None + super(MaskEdges, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["axes"] = self.axes + config["boundaries"] = self.boundaries + config["prob_mask"] = self.prob_mask + return config + + def build(self, input_shape): + self.inputshape = input_shape + self.built = True + super(MaskEdges, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # build mask + mask = tf.ones_like(inputs) + for i, axis in enumerate(self.axes): + + # select restricting indices + axis_boundaries = self.boundaries[i, :] + idx1 = tf.math.round( + tf.random.uniform( + [1], + minval=axis_boundaries[0] * self.inputshape[axis], + maxval=axis_boundaries[1] * self.inputshape[axis], + ) + ) + idx2 = tf.math.round( + tf.random.uniform( + [1], + minval=axis_boundaries[2] * self.inputshape[axis], + maxval=axis_boundaries[3] * self.inputshape[axis] - 1, + ) + - idx1 + ) + idx3 = self.inputshape[axis] - idx1 - idx2 + split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype="int32") + + # update mask + split_list = tf.split(inputs, split_idx, axis=axis) + tmp_mask = tf.concat( + [ + tf.zeros_like(split_list[0]), + tf.ones_like(split_list[1]), + tf.zeros_like(split_list[2]), + ], + axis=axis, + ) + mask = mask * tmp_mask + + # mask second_channel + tensor = K.switch( + tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), + inputs * mask, + inputs, + ) + + return [tensor, mask] + + def compute_output_shape(self, input_shape): + return [input_shape] * 2 + + +class ImageGradients(Layer): + + def __init__(self, gradient_type="sobel", return_magnitude=False, **kwargs): + + self.gradient_type = gradient_type + assert (self.gradient_type == "sobel") | ( + self.gradient_type == "1-step_diff" + ), ( + "gradient_type should be either sobel or 1-step_diff, had %s" + % self.gradient_type + ) + + # shape + self.n_dims = 0 + self.shape = None + self.n_channels = 0 + + # convolution params if sobel diff + self.stride = None + self.kernels = None + self.convnd = None + + self.return_magnitude = return_magnitude + + super(ImageGradients, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["gradient_type"] = self.gradient_type + config["return_magnitude"] = self.return_magnitude + return config + + def build(self, input_shape): + + # get shapes + self.n_dims = len(input_shape) - 2 + self.shape = input_shape[1:] + self.n_channels = input_shape[-1] + + # prepare kernel if sobel gradients + if self.gradient_type == "sobel": + self.kernels = l2i_et.sobel_kernels(self.n_dims) + self.stride = [1] * (self.n_dims + 2) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) + else: + self.kernels = self.convnd = self.stride = None + + self.built = True + super(ImageGradients, self).build(input_shape) + + def call(self, inputs, **kwargs): + + image = inputs + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + gradients = list() + + # sobel method + if self.gradient_type == "sobel": + # get sobel gradients in each direction + for n in range(self.n_dims): + gradient = image + # apply 1D kernel in each direction (sobel kernels are separable), instead of applying a nD kernel + for k in self.kernels[n]: + gradient = tf.concat( + [ + self.convnd( + tf.expand_dims(gradient[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) + gradients.append(gradient) + + # 1-step method, only supports 2 and 3D + else: + + # get 1-step diff + if self.n_dims == 2: + gradients.append(image[:, 1:, :, :] - image[:, :-1, :, :]) # dx + gradients.append(image[:, :, 1:, :] - image[:, :, :-1, :]) # dy + + elif self.n_dims == 3: + gradients.append(image[:, 1:, :, :, :] - image[:, :-1, :, :, :]) # dx + gradients.append(image[:, :, 1:, :, :] - image[:, :, :-1, :, :]) # dy + gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :]) # dz + + else: + raise Exception( + "ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD" + % self.n_dims + ) + + # pad with zeros to return tensors of the same shape as input + for i in range(self.n_dims): + tmp_shape = list(self.shape) + tmp_shape[i] = 1 + zeros = tf.zeros( + tf.concat( + [batchsize, tf.convert_to_tensor(tmp_shape, dtype="int32")], 0 + ), + image.dtype, + ) + gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1) + + # compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis + if self.return_magnitude: + gradients = tf.sqrt( + tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1) + ) + else: + gradients = tf.concat(gradients, axis=-1) + + return gradients + + def compute_output_shape(self, input_shape): + if not self.return_magnitude: + input_shape = list(input_shape) + input_shape[-1] = self.n_dims + return tuple(input_shape) + + +class RandomDilationErosion(Layer): + """ + GPU implementation of binary dilation or erosion. The operation can be chosen to be always a dilation, or always an + erosion, or randomly choosing between them for each element of the batch. + The chosen operation is applied to the input with a given probability. Moreover, it is also possible to randomise + the factor of the operation for each element of the mini-batch. + :param min_factor: minimum possible value for the dilation/erosion factor. Must be an integer. + :param max_factor: minimum possible value for the dilation/erosion factor. Must be an integer. + Set it to the same value as min_factor to always perform dilation/erosion with the same factor. + :param prob: probability with which to apply the selected operation to the input. + :param operation: which operation to apply. Can be 'dilation' or 'erosion' or 'random'. + :param return_mask: if operation is erosion and the input of this layer is a label map, we have the + choice to either return the eroded label map or the mask (return_mask=True) + """ + + def __init__( + self, + min_factor, + max_factor, + max_factor_dilate=None, + prob=1, + operation="random", + return_mask=False, + **kwargs + ): + + self.min_factor = min_factor + self.max_factor = max_factor + self.max_factor_dilate = ( + max_factor_dilate if max_factor_dilate is not None else self.max_factor + ) + self.prob = prob + self.operation = operation + self.return_mask = return_mask + self.n_dims = None + self.inshape = None + self.n_channels = None + self.convnd = None + super(RandomDilationErosion, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["min_factor"] = self.min_factor + config["max_factor"] = self.max_factor + config["max_factor_dilate"] = self.max_factor_dilate + config["prob"] = self.prob + config["operation"] = self.operation + config["return_mask"] = self.return_mask + return config + + def build(self, input_shape): + + # input shape + self.inshape = input_shape + self.n_dims = len(self.inshape) - 2 + self.n_channels = self.inshape[-1] + + # prepare convolution + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) + + self.built = True + super(RandomDilationErosion, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # sample probability of applying operation. If random negative is erosion and positive is dilation + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype="int32")], axis=0) + if self.operation == "dilation": + prob = tf.random.uniform(shape, 0, 1) + elif self.operation == "erosion": + prob = tf.random.uniform(shape, -1, 0) + elif self.operation == "random": + prob = tf.random.uniform(shape, -1, 1) + else: + raise ValueError( + "operation should either be 'dilation' 'erosion' or 'random', had %s" + % self.operation + ) + + # build kernel + if self.min_factor == self.max_factor: + dist_threshold = self.min_factor * tf.ones(shape, dtype="int32") + else: + if (self.max_factor == self.max_factor_dilate) | ( + self.operation != "random" + ): + dist_threshold = tf.random.uniform( + shape, minval=self.min_factor, maxval=self.max_factor, dtype="int32" + ) + else: + dist_threshold = tf.cast( + tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), + dtype="int32", + ) + kernel = l2i_et.unit_kernel( + dist_threshold, self.n_dims, max_dist_threshold=self.max_factor + ) + + # convolve input mask with kernel according to given probability + mask = tf.cast(tf.cast(inputs, dtype="bool"), dtype="float32") + mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32) + mask = tf.cast(mask, "bool") + + if self.return_mask: + return mask + else: + return inputs * tf.cast(mask, dtype=inputs.dtype) + + def _sample_factor(self, inputs): + return tf.cast( + K.switch( + K.less(tf.squeeze(inputs[0]), 0), + tf.random.uniform( + (1,), self.min_factor, self.max_factor, dtype="int32" + ), + tf.random.uniform( + (1,), self.min_factor, self.max_factor_dilate, dtype="int32" + ), + ), + dtype="float32", + ) + + def _single_blur(self, inputs): + # dilate... + new_mask = K.switch( + K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), + tf.cast( + tf.greater( + tf.squeeze( + self.convnd( + tf.expand_dims(inputs[0], 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ), + axis=0, + ), + 0.01, + ), + dtype="float32", + ), + inputs[0], + ) + # ...or erode + new_mask = K.switch( + K.less(tf.squeeze(inputs[2]), -(1 - self.prob + 0.001)), + 1 + - tf.cast( + tf.greater( + tf.squeeze( + self.convnd( + tf.expand_dims(1 - new_mask, 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ), + axis=0, + ), + 0.01, + ), + dtype="float32", + ), + new_mask, + ) + return new_mask + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/nobrainer/ext/lab2im/utils.py b/nobrainer/ext/lab2im/utils.py new file mode 100644 index 00000000..66e5c03f --- /dev/null +++ b/nobrainer/ext/lab2im/utils.py @@ -0,0 +1,1391 @@ +""" +This file contains all the utilities used in that project. They are classified in 5 categories: +1- loading/saving functions: + -load_volume + -save_volume + -get_volume_info + -get_list_labels + -load_array_if_path + -write_pickle + -read_pickle + -write_model_summary +2- reformatting functions + -reformat_to_list + -reformat_to_n_channels_array +3- path related functions + -list_images_in_folder + -list_files + -list_subfolders + -strip_extension + -strip_suffix + -mkdir + -mkcmd +4- shape-related functions + -get_dims + -get_resample_shape + -add_axis + -get_padding_margin +5- build affine matrices/tensors + -create_affine_transformation_matrix + -sample_affine_transform + -create_rotation_transform + -create_shearing_transform +6- miscellaneous + -infer + -LoopInfo + -get_mapping_lut + -build_training_generator + -find_closest_number_divisible_by_m + -build_binary_structure + -draw_value_from_distribution + -build_exp + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +from datetime import timedelta +import glob +import math +import os +import pickle +import time + +import keras.backend as K +import keras.layers as KL +import nibabel as nib +import numpy as np +from scipy.ndimage.morphology import distance_transform_edt +import tensorflow as tf + +# ---------------------------------------------- loading/saving functions ---------------------------------------------- + + +def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None): + """ + Load volume file. + :param path_volume: path of the volume to load. Can either be a nii, nii.gz, mgz, or npz format. + If npz format, 1) the variable name is assumed to be 'vol_data', + 2) the volume is associated with an identity affine matrix and blank header. + :param im_only: (optional) if False, the function also returns the affine matrix and header of the volume. + :param squeeze: (optional) whether to squeeze the volume when loading. + :param dtype: (optional) if not None, convert the loaded volume to this numpy dtype. + :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix. + The returned affine matrix is also given in this new space. Must be a numpy array of dimension 4x4. + :return: the volume, with corresponding affine matrix and header if im_only is False. + """ + assert path_volume.endswith((".nii", ".nii.gz", ".mgz", ".npz")), ( + "Unknown data file: %s" % path_volume + ) + + if path_volume.endswith((".nii", ".nii.gz", ".mgz")): + x = nib.load(path_volume) + if squeeze: + volume = np.squeeze(x.get_fdata()) + else: + volume = x.get_fdata() + aff = x.affine + header = x.header + else: # npz + volume = np.load(path_volume)["vol_data"] + if squeeze: + volume = np.squeeze(volume) + aff = np.eye(4) + header = nib.Nifti1Header() + if dtype is not None: + if "int" in dtype: + volume = np.round(volume) + volume = volume.astype(dtype=dtype) + + # align image to reference affine matrix + if aff_ref is not None: + from nobrainer.ext.lab2im import ( # the import is done here to avoid import loops + edit_volumes, + ) + + n_dims, _ = get_dims(list(volume.shape), max_channels=10) + volume, aff = edit_volumes.align_volume_to_ref( + volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims + ) + + if im_only: + return volume + else: + return volume, aff, header + + +def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): + """ + Save a volume. + :param volume: volume to save + :param aff: affine matrix of the volume to save. If aff is None, the volume is saved with an identity affine matrix. + aff can also be set to 'FS', in which case the volume is saved with the affine matrix of FreeSurfer outputs. + :param header: header of the volume to save. If None, the volume is saved with a blank header. + :param path: path where to save the volume. + :param res: (optional) update the resolution in the header before saving the volume. + :param dtype: (optional) numpy dtype for the saved volume. + :param n_dims: (optional) number of dimensions, to avoid confusion in multi-channel case. Default is None, where + n_dims is automatically inferred. + """ + + mkdir(os.path.dirname(path)) + if ".npz" in path: + np.savez_compressed(path, vol_data=volume) + else: + if header is None: + header = nib.Nifti1Header() + if isinstance(aff, str): + if aff == "FS": + aff = np.array( + [[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]] + ) + elif aff is None: + aff = np.eye(4) + if dtype is not None: + if "int" in dtype: + volume = np.round(volume) + volume = volume.astype(dtype=dtype) + nifty = nib.Nifti1Image(volume, aff, header) + nifty.set_data_dtype(dtype) + else: + nifty = nib.Nifti1Image(volume, aff, header) + if res is not None: + if n_dims is None: + n_dims, _ = get_dims(volume.shape) + res = reformat_to_list(res, length=n_dims, dtype=None) + nifty.header.set_zooms(res) + nib.save(nifty, path) + + +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): + """ + Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. + :param path_volume: path of the volume to get information form. + :param return_volume: (optional) whether to return the volume along with the information. + :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix. + All info relative to the volume is then given in this new space. Must be a numpy array of dimension 4x4. + :param max_channels: maximum possible number of channels for the input volume. + :return: volume (if return_volume is true), and corresponding info. If aff_ref is not None, the returned aff is + the original one, i.e. the affine of the image before being aligned to aff_ref. + """ + # read image + im, aff, header = load_volume(path_volume, im_only=False) + + # understand if image is multichannel + im_shape = list(im.shape) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) + im_shape = im_shape[:n_dims] + + # get labels res + if ".nii" in path_volume: + data_res = np.array(header["pixdim"][1 : n_dims + 1]) + elif ".mgz" in path_volume: + data_res = np.array(header["delta"]) # mgz image + else: + data_res = np.array([1.0] * n_dims) + + # align to given affine matrix + if aff_ref is not None: + from nobrainer.ext.lab2im import ( # the import is done here to avoid import loops + edit_volumes, + ) + + ras_axes = edit_volumes.get_ras_axes(aff, n_dims=n_dims) + ras_axes_ref = edit_volumes.get_ras_axes(aff_ref, n_dims=n_dims) + im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims) + im_shape = np.array(im_shape) + data_res = np.array(data_res) + im_shape[ras_axes_ref] = im_shape[ras_axes] + data_res[ras_axes_ref] = data_res[ras_axes] + im_shape = im_shape.tolist() + + # return info + if return_volume: + return im, im_shape, aff, n_dims, n_channels, header, data_res + else: + return im_shape, aff, n_dims, n_channels, header, data_res + + +def get_list_labels( + label_list=None, labels_dir=None, save_label_list=None, FS_sort=False +): + """This function reads or computes a list of all label values used in a set of label maps. + It can also sort all labels according to FreeSurfer lut. + :param label_list: (optional) already computed label_list. Can be a sequence, a 1d numpy array, or the path to + a numpy 1d array. + :param labels_dir: (optional) if path_label_list is None, the label list is computed by reading all the label maps + in the given folder. Can also be the path to a single label map. + :param save_label_list: (optional) path where to save the label list. + :param FS_sort: (optional) whether to sort label values according to the FreeSurfer classification. + If true, the label values will be ordered as follows: neutral labels first (i.e. non-sided), left-side labels, + and right-side labels. If FS_sort is True, this function also returns the number of neutral labels in label_list. + :return: the label list (numpy 1d array), and the number of neutral (i.e. non-sided) labels if FS_sort is True. + If one side of the brain is not represented at all in label_list, all labels are considered as neutral, and + n_neutral_labels = len(label_list). + """ + + # load label list if previously computed + if label_list is not None: + label_list = np.array( + reformat_to_list(label_list, load_as_numpy=True, dtype="int") + ) + + # compute label list from all label files + elif labels_dir is not None: + print("Compiling list of unique labels") + # go through all labels files and compute unique list of labels + labels_paths = list_images_in_folder(labels_dir) + label_list = np.empty(0) + loop_info = LoopInfo(len(labels_paths), 10, "processing", print_time=True) + for lab_idx, path in enumerate(labels_paths): + loop_info.update(lab_idx) + y = load_volume(path, dtype="int32") + y_unique = np.unique(y) + label_list = np.unique(np.concatenate((label_list, y_unique))).astype("int") + + else: + raise Exception( + "either label_list, path_label_list or labels_dir should be provided" + ) + + # sort labels in neutral/left/right according to FS labels + n_neutral_labels = 0 + if FS_sort: + neutral_FS_labels = [ + 0, + 14, + 15, + 16, + 21, + 22, + 23, + 24, + 72, + 77, + 80, + 85, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 165, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 251, + 252, + 253, + 254, + 255, + 258, + 259, + 260, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 502, + 506, + 507, + 508, + 509, + 511, + 512, + 514, + 515, + 516, + 517, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + ] + neutral = list() + left = list() + right = list() + for la in label_list: + if la in neutral_FS_labels: + if la not in neutral: + neutral.append(la) + elif ( + (0 < la < 14) + | (16 < la < 21) + | (24 < la < 40) + | (135 < la < 139) + | (1000 <= la <= 1035) + | (la == 865) + | (20100 < la < 20110) + ): + if la not in left: + left.append(la) + elif ( + (39 < la < 72) + | (162 < la < 165) + | (2000 <= la <= 2035) + | (20000 < la < 20010) + | (la == 139) + | (la == 866) + ): + if la not in right: + right.append(la) + else: + raise Exception( + "label {} not in our current FS classification, " + "please update get_list_labels in utils.py".format(la) + ) + label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)]) + if ((len(left) > 0) & (len(right) > 0)) | ( + (len(left) == 0) & (len(right) == 0) + ): + n_neutral_labels = len(neutral) + else: + n_neutral_labels = len(label_list) + + # save labels if specified + if save_label_list is not None: + np.save(save_label_list, np.int32(label_list)) + + if FS_sort: + return np.int32(label_list), n_neutral_labels + else: + return np.int32(label_list), None + + +def load_array_if_path(var, load_as_numpy=True): + """If var is a string and load_as_numpy is True, this function loads the array writen at the path indicated by var. + Otherwise it simply returns var as it is.""" + if (isinstance(var, str)) & load_as_numpy: + assert os.path.isfile(var), "No such path: %s" % var + var = np.load(var) + return var + + +def write_pickle(filepath, obj): + """write a python object with a pickle at a given path""" + with open(filepath, "wb") as file: + pickler = pickle.Pickler(file) + pickler.dump(obj) + + +def read_pickle(filepath): + """read a python object with a pickle""" + with open(filepath, "rb") as file: + unpickler = pickle.Unpickler(file) + return unpickler.load() + + +def write_model_summary(model, filepath="./model_summary.txt", line_length=150): + """Write the summary of a keras model at a given path, with a given length for each line""" + with open(filepath, "w") as fh: + model.summary(print_fn=lambda x: fh.write(x + "\n"), line_length=line_length) + + +# ----------------------------------------------- reformatting functions ----------------------------------------------- + + +def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): + """This function takes a variable and reformat it into a list of desired + length and type (int, float, bool, str). + If variable is a string, and load_as_numpy is True, it will be loaded as a numpy array. + If variable is None, this function returns None. + :param var: a str, int, float, list, tuple, or numpy array + :param length: (optional) if var is a single item, it will be replicated to a list of this length + :param load_as_numpy: (optional) whether var is the path to a numpy array + :param dtype: (optional) convert all item to this type. Can be 'int', 'float', 'bool', or 'str' + :return: reformatted list + """ + + # convert to list + if var is None: + return None + var = load_array_if_path(var, load_as_numpy=load_as_numpy) + if isinstance(var, (int, float, np.int32, np.int64, np.float32, np.float64)): + var = [var] + elif isinstance(var, tuple): + var = list(var) + elif isinstance(var, np.ndarray): + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() + elif isinstance(var, str): + var = [var] + elif isinstance(var, bool): + var = [var] + if isinstance(var, list): + if length is not None: + if len(var) == 1: + var = var * length + elif len(var) != length: + raise ValueError( + "if var is a list/tuple/numpy array, it should be of length 1 or {0}, " + "had {1}".format(length, var) + ) + else: + raise TypeError( + "var should be an int, float, tuple, list, numpy array, or path to numpy array" + ) + + # convert items type + if dtype is not None: + if dtype == "int": + var = [int(v) for v in var] + elif dtype == "float": + var = [float(v) for v in var] + elif dtype == "bool": + var = [bool(v) for v in var] + elif dtype == "str": + var = [str(v) for v in var] + else: + raise ValueError( + "dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype) + ) + return var + + +def reformat_to_n_channels_array(var, n_dims=3, n_channels=1): + """This function takes an int, float, list or tuple and reformat it to an array of shape (n_channels, n_dims). + If resolution is a str, it will be assumed to be the path of a numpy array. + If resolution is a numpy array, it will be checked to have shape (n_channels, n_dims). + Finally if resolution is None, this function returns None as well.""" + if var is None: + return [None] * n_channels + if isinstance(var, str): + var = np.load(var) + # convert to numpy array + if isinstance(var, (int, float, list, tuple)): + var = reformat_to_list(var, n_dims) + var = np.tile(np.array(var), (n_channels, 1)) + # check shape if numpy array + elif isinstance(var, np.ndarray): + if n_channels == 1: + var = var.reshape((1, n_dims)) + else: + if np.squeeze(var).shape == (n_dims,): + var = np.tile(var.reshape((1, n_dims)), (n_channels, 1)) + elif var.shape != (n_channels, n_dims): + raise ValueError( + "if array, var should be {0} or {1}".format( + (1, n_dims), (n_channels, n_dims) + ) + ) + else: + raise TypeError("var should be int, float, list, tuple or ndarray") + return np.round(var, 3) + + +# ----------------------------------------------- path-related functions ----------------------------------------------- + + +def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True): + """List all files with extension nii, nii.gz, mgz, or npz within a folder.""" + basename = os.path.basename(path_dir) + if include_single_image & ( + (".nii.gz" in basename) + | (".nii" in basename) + | (".mgz" in basename) + | (".npz" in basename) + ): + assert os.path.isfile(path_dir), "file %s does not exist" % path_dir + list_images = [path_dir] + else: + if os.path.isdir(path_dir): + list_images = sorted( + glob.glob(os.path.join(path_dir, "*nii.gz")) + + glob.glob(os.path.join(path_dir, "*nii")) + + glob.glob(os.path.join(path_dir, "*.mgz")) + + glob.glob(os.path.join(path_dir, "*.npz")) + ) + else: + raise Exception("Folder does not exist: %s" % path_dir) + if check_if_empty: + assert len(list_images) > 0, ( + "no .nii, .nii.gz, .mgz or .npz image could be found in %s" % path_dir + ) + return list_images + + +def list_files(path_dir, whole_path=True, expr=None, cond_type="or"): + """This function returns a list of files contained in a folder, with possible regexp. + :param path_dir: path of a folder + :param whole_path: (optional) whether to return whole path or just the filenames. + :param expr: (optional) regexp for files to list. Can be a str or a list of str. + :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp. + Can be 'or', or 'and'. + :return: a list of files + """ + assert isinstance(whole_path, bool), "whole_path should be bool" + assert cond_type in ["or", "and"], "cond_type should be either 'or', or 'and'" + if whole_path: + files_list = sorted( + [ + os.path.join(path_dir, f) + for f in os.listdir(path_dir) + if os.path.isfile(os.path.join(path_dir, f)) + ] + ) + else: + files_list = sorted( + [ + f + for f in os.listdir(path_dir) + if os.path.isfile(os.path.join(path_dir, f)) + ] + ) + if expr is not None: # assumed to be either str or list of str + if isinstance(expr, str): + expr = [expr] + elif not isinstance(expr, (list, tuple)): + raise Exception( + "if specified, 'expr' should be a string or list of strings." + ) + matched_list_files = list() + for match in expr: + tmp_matched_files_list = sorted( + [f for f in files_list if match in os.path.basename(f)] + ) + if cond_type == "or": + files_list = [f for f in files_list if f not in tmp_matched_files_list] + matched_list_files += tmp_matched_files_list + elif cond_type == "and": + files_list = tmp_matched_files_list + matched_list_files = tmp_matched_files_list + files_list = sorted(matched_list_files) + return files_list + + +def list_subfolders(path_dir, whole_path=True, expr=None, cond_type="or"): + """This function returns a list of subfolders contained in a folder, with possible regexp. + :param path_dir: path of a folder + :param whole_path: (optional) whether to return whole path or just the subfolder names. + :param expr: (optional) regexp for files to list. Can be a str or a list of str. + :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp. + Can be 'or', or 'and'. + :return: a list of subfolders + """ + assert isinstance(whole_path, bool), "whole_path should be bool" + assert cond_type in ["or", "and"], "cond_type should be either 'or', or 'and'" + if whole_path: + subdirs_list = sorted( + [ + os.path.join(path_dir, f) + for f in os.listdir(path_dir) + if os.path.isdir(os.path.join(path_dir, f)) + ] + ) + else: + subdirs_list = sorted( + [ + f + for f in os.listdir(path_dir) + if os.path.isdir(os.path.join(path_dir, f)) + ] + ) + if expr is not None: # assumed to be either str or list of str + if isinstance(expr, str): + expr = [expr] + elif not isinstance(expr, (list, tuple)): + raise Exception( + "if specified, 'expr' should be a string or list of strings." + ) + matched_list_subdirs = list() + for match in expr: + tmp_matched_list_subdirs = sorted( + [f for f in subdirs_list if match in os.path.basename(f)] + ) + if cond_type == "or": + subdirs_list = [ + f for f in subdirs_list if f not in tmp_matched_list_subdirs + ] + matched_list_subdirs += tmp_matched_list_subdirs + elif cond_type == "and": + subdirs_list = tmp_matched_list_subdirs + matched_list_subdirs = tmp_matched_list_subdirs + subdirs_list = sorted(matched_list_subdirs) + return subdirs_list + + +def get_image_extension(path): + name = os.path.basename(path) + if name[-7:] == ".nii.gz": + return "nii.gz" + elif name[-4:] == ".mgz": + return "mgz" + elif name[-4:] == ".nii": + return "nii" + elif name[-4:] == ".npz": + return "npz" + + +def strip_extension(path): + """Strip classical image extensions (.nii.gz, .nii, .mgz, .npz) from a filename.""" + return ( + path.replace(".nii.gz", "") + .replace(".nii", "") + .replace(".mgz", "") + .replace(".npz", "") + ) + + +def strip_suffix(path): + """Strip classical image suffix from a filename.""" + path = path.replace("_aseg", "") + path = path.replace("aseg", "") + path = path.replace(".aseg", "") + path = path.replace("_aseg_1", "") + path = path.replace("_aseg_2", "") + path = path.replace("aseg_1_", "") + path = path.replace("aseg_2_", "") + path = path.replace("_orig", "") + path = path.replace("orig", "") + path = path.replace(".orig", "") + path = path.replace("_norm", "") + path = path.replace("norm", "") + path = path.replace(".norm", "") + path = path.replace("_talairach", "") + path = path.replace("GSP_FS_4p5", "GSP") + path = path.replace(".nii_crispSegmentation", "") + path = path.replace("_crispSegmentation", "") + path = path.replace("_seg", "") + path = path.replace(".seg", "") + path = path.replace("seg", "") + path = path.replace("_seg_1", "") + path = path.replace("_seg_2", "") + path = path.replace("seg_1_", "") + path = path.replace("seg_2_", "") + return path + + +def mkdir(path_dir): + """Recursively creates the current dir as well as its parent folders if they do not already exist.""" + if path_dir[-1] == "/": + path_dir = path_dir[:-1] + if not os.path.isdir(path_dir): + list_dir_to_create = [path_dir] + while not os.path.isdir(os.path.dirname(list_dir_to_create[-1])): + list_dir_to_create.append(os.path.dirname(list_dir_to_create[-1])) + for dir_to_create in reversed(list_dir_to_create): + os.mkdir(dir_to_create) + + +def mkcmd(*args): + """Creates terminal command with provided inputs. + Example: mkcmd('mv', 'source', 'dest') will give 'mv source dest'.""" + return " ".join([str(arg) for arg in args]) + + +# ---------------------------------------------- shape-related functions ----------------------------------------------- + + +def get_dims(shape, max_channels=10): + """Get the number of dimensions and channels from the shape of an array. + The number of dimensions is assumed to be the length of the shape, as long as the shape of the last dimension is + inferior or equal to max_channels (default 3). + :param shape: shape of an array. Can be a sequence or a 1d numpy array. + :param max_channels: maximum possible number of channels. + :return: the number of dimensions and channels associated with the provided shape. + example 1: get_dims([150, 150, 150], max_channels=10) = (3, 1) + example 2: get_dims([150, 150, 150, 3], max_channels=10) = (3, 3) + example 3: get_dims([150, 150, 150, 15], max_channels=10) = (4, 1), because 5>3""" + if shape[-1] <= max_channels: + n_dims = len(shape) - 1 + n_channels = shape[-1] + else: + n_dims = len(shape) + n_channels = 1 + return n_dims, n_channels + + +def get_resample_shape(patch_shape, factor, n_channels=None): + """Compute the shape of a resampled array given a shape factor. + :param patch_shape: size of the initial array (without number of channels). + :param factor: resampling factor. Can be a number, sequence, or 1d numpy array. + :param n_channels: (optional) if not None, add a number of channel at the end of the computed shape. + :return: list containing the shape of the input array after being resampled by the given factor. + """ + factor = reformat_to_list(factor, length=len(patch_shape)) + shape = [math.ceil(patch_shape[i] * factor[i]) for i in range(len(patch_shape))] + if n_channels is not None: + shape += [n_channels] + return shape + + +def add_axis(x, axis=0): + """Add axis to a numpy array. + :param x: input array + :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time. + """ + axis = reformat_to_list(axis) + for ax in axis: + x = np.expand_dims(x, axis=ax) + return x + + +def get_padding_margin(cropping, loss_cropping): + """Compute padding margin""" + if (cropping is not None) & (loss_cropping is not None): + cropping = reformat_to_list(cropping) + loss_cropping = reformat_to_list(loss_cropping) + n_dims = max(len(cropping), len(loss_cropping)) + cropping = reformat_to_list(cropping, length=n_dims) + loss_cropping = reformat_to_list(loss_cropping, length=n_dims) + padding_margin = [ + int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims) + ] + if len(padding_margin) == 1: + padding_margin = padding_margin[0] + else: + padding_margin = None + return padding_margin + + +# -------------------------------------------- build affine matrices/tensors ------------------------------------------- + + +def create_affine_transformation_matrix( + n_dims, scaling=None, rotation=None, shearing=None, translation=None +): + """Create a 4x4 affine transformation matrix from specified values + :param n_dims: integer, can either be 2 or 3. + :param scaling: list of 3 scaling values + :param rotation: list of 3 angles (degrees) for rotations around 1st, 2nd, 3rd axis + :param shearing: list of 6 shearing values + :param translation: list of 3 values + :return: 4x4 numpy matrix + """ + + T_scaling = np.eye(n_dims + 1) + T_shearing = np.eye(n_dims + 1) + T_translation = np.eye(n_dims + 1) + + if scaling is not None: + T_scaling[np.arange(n_dims + 1), np.arange(n_dims + 1)] = np.append(scaling, 1) + + if shearing is not None: + shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype="bool") + shearing_index[np.eye(n_dims + 1, dtype="bool")] = False + shearing_index[-1, :] = np.zeros((n_dims + 1)) + shearing_index[:, -1] = np.zeros((n_dims + 1)) + T_shearing[shearing_index] = shearing + + if translation is not None: + T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype="int")] = ( + translation + ) + + if n_dims == 2: + if rotation is None: + rotation = np.zeros(1) + else: + rotation = np.asarray(rotation) * (math.pi / 180) + T_rot = np.eye(n_dims + 1) + T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [ + np.cos(rotation[0]), + np.sin(rotation[0]), + np.sin(rotation[0]) * -1, + np.cos(rotation[0]), + ] + return T_translation @ T_rot @ T_shearing @ T_scaling + + else: + + if rotation is None: + rotation = np.zeros(n_dims) + else: + rotation = np.asarray(rotation) * (math.pi / 180) + T_rot1 = np.eye(n_dims + 1) + T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [ + np.cos(rotation[0]), + np.sin(rotation[0]), + np.sin(rotation[0]) * -1, + np.cos(rotation[0]), + ] + T_rot2 = np.eye(n_dims + 1) + T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [ + np.cos(rotation[1]), + np.sin(rotation[1]) * -1, + np.sin(rotation[1]), + np.cos(rotation[1]), + ] + T_rot3 = np.eye(n_dims + 1) + T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [ + np.cos(rotation[2]), + np.sin(rotation[2]), + np.sin(rotation[2]) * -1, + np.cos(rotation[2]), + ] + return T_translation @ T_rot3 @ T_rot2 @ T_rot1 @ T_shearing @ T_scaling + + +def sample_affine_transform( + batchsize, + n_dims, + rotation_bounds=False, + scaling_bounds=False, + shearing_bounds=False, + translation_bounds=False, + enable_90_rotations=False, +): + """build batchsize x 4 x 4 tensor representing an affine transformation in homogeneous coordinates. + If return_inv is True, also returns the inverse of the created affine matrix.""" + + if (rotation_bounds is not False) | (enable_90_rotations is not False): + if n_dims == 2: + if rotation_bounds is not False: + rotation = draw_value_from_distribution( + rotation_bounds, + size=1, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize, + ) + else: + rotation = tf.zeros( + tf.concat([batchsize, tf.ones(1, dtype="int32")], axis=0) + ) + else: # n_dims = 3 + if rotation_bounds is not False: + rotation = draw_value_from_distribution( + rotation_bounds, + size=n_dims, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize, + ) + else: + rotation = tf.zeros( + tf.concat([batchsize, 3 * tf.ones(1, dtype="int32")], axis=0) + ) + if enable_90_rotations: + rotation = ( + tf.cast( + tf.random.uniform(tf.shape(rotation), maxval=4, dtype="int32") * 90, + "float32", + ) + + rotation + ) + T_rot = create_rotation_transform(rotation, n_dims) + else: + T_rot = tf.tile( + tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0), + ) + + if shearing_bounds is not False: + shearing = draw_value_from_distribution( + shearing_bounds, + size=n_dims**2 - n_dims, + default_range=0.01, + return_as_tensor=True, + batchsize=batchsize, + ) + T_shearing = create_shearing_transform(shearing, n_dims) + else: + T_shearing = tf.tile( + tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0), + ) + + if scaling_bounds is not False: + scaling = draw_value_from_distribution( + scaling_bounds, + size=n_dims, + centre=1, + default_range=0.15, + return_as_tensor=True, + batchsize=batchsize, + ) + T_scaling = tf.linalg.diag(scaling) + else: + T_scaling = tf.tile( + tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0), + ) + + T = tf.matmul(T_scaling, tf.matmul(T_shearing, T_rot)) + + if translation_bounds is not False: + translation = draw_value_from_distribution( + translation_bounds, + size=n_dims, + default_range=5, + return_as_tensor=True, + batchsize=batchsize, + ) + T = tf.concat([T, tf.expand_dims(translation, axis=-1)], axis=-1) + else: + T = tf.concat( + [T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype="int32")], 0))], + axis=-1, + ) + + # build rigid transform + T_last_row = tf.expand_dims( + tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0 + ) + T_last_row = tf.tile( + T_last_row, tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0) + ) + T = tf.concat([T, T_last_row], axis=1) + + return T + + +def create_rotation_transform(rotation, n_dims): + """build rotation transform from 3d or 2d rotation coefficients. Angles are given in degrees.""" + rotation = rotation * np.pi / 180 + if n_dims == 3: + shape = tf.shape(tf.expand_dims(rotation[..., 0], -1)) + + Rx_row0 = tf.expand_dims( + tf.tile(tf.expand_dims(tf.convert_to_tensor([1.0, 0.0, 0.0]), 0), shape), + axis=1, + ) + Rx_row1 = tf.stack( + [ + tf.zeros(shape), + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(-tf.sin(rotation[..., 0]), -1), + ], + axis=-1, + ) + Rx_row2 = tf.stack( + [ + tf.zeros(shape), + tf.expand_dims(tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + ], + axis=-1, + ) + Rx = tf.concat([Rx_row0, Rx_row1, Rx_row2], axis=1) + + Ry_row0 = tf.stack( + [ + tf.expand_dims(tf.cos(rotation[..., 1]), -1), + tf.zeros(shape), + tf.expand_dims(tf.sin(rotation[..., 1]), -1), + ], + axis=-1, + ) + Ry_row1 = tf.expand_dims( + tf.tile(tf.expand_dims(tf.convert_to_tensor([0.0, 1.0, 0.0]), 0), shape), + axis=1, + ) + Ry_row2 = tf.stack( + [ + tf.expand_dims(-tf.sin(rotation[..., 1]), -1), + tf.zeros(shape), + tf.expand_dims(tf.cos(rotation[..., 1]), -1), + ], + axis=-1, + ) + Ry = tf.concat([Ry_row0, Ry_row1, Ry_row2], axis=1) + + Rz_row0 = tf.stack( + [ + tf.expand_dims(tf.cos(rotation[..., 2]), -1), + tf.expand_dims(-tf.sin(rotation[..., 2]), -1), + tf.zeros(shape), + ], + axis=-1, + ) + Rz_row1 = tf.stack( + [ + tf.expand_dims(tf.sin(rotation[..., 2]), -1), + tf.expand_dims(tf.cos(rotation[..., 2]), -1), + tf.zeros(shape), + ], + axis=-1, + ) + Rz_row2 = tf.expand_dims( + tf.tile(tf.expand_dims(tf.convert_to_tensor([0.0, 0.0, 1.0]), 0), shape), + axis=1, + ) + Rz = tf.concat([Rz_row0, Rz_row1, Rz_row2], axis=1) + + T_rot = tf.matmul(tf.matmul(Rx, Ry), Rz) + + elif n_dims == 2: + R_row0 = tf.stack( + [ + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(tf.sin(rotation[..., 0]), -1), + ], + axis=-1, + ) + R_row1 = tf.stack( + [ + tf.expand_dims(-tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + ], + axis=-1, + ) + T_rot = tf.concat([R_row0, R_row1], axis=1) + + else: + raise Exception("only supports 2 or 3D.") + + return T_rot + + +def create_shearing_transform(shearing, n_dims): + """build shearing transform from 2d/3d shearing coefficients""" + shape = tf.shape(tf.expand_dims(shearing[..., 0], -1)) + if n_dims == 3: + shearing_row0 = tf.stack( + [ + tf.ones(shape), + tf.expand_dims(shearing[..., 0], -1), + tf.expand_dims(shearing[..., 1], -1), + ], + axis=-1, + ) + shearing_row1 = tf.stack( + [ + tf.expand_dims(shearing[..., 2], -1), + tf.ones(shape), + tf.expand_dims(shearing[..., 3], -1), + ], + axis=-1, + ) + shearing_row2 = tf.stack( + [ + tf.expand_dims(shearing[..., 4], -1), + tf.expand_dims(shearing[..., 5], -1), + tf.ones(shape), + ], + axis=-1, + ) + T_shearing = tf.concat([shearing_row0, shearing_row1, shearing_row2], axis=1) + + elif n_dims == 2: + shearing_row0 = tf.stack( + [tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1 + ) + shearing_row1 = tf.stack( + [tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1 + ) + T_shearing = tf.concat([shearing_row0, shearing_row1], axis=1) + else: + raise Exception("only supports 2 or 3D.") + return T_shearing + + +# --------------------------------------------------- miscellaneous ---------------------------------------------------- + + +def infer(x): + """Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string""" + try: + x = float(x) + except ValueError: + if x == "False": + x = False + elif x == "True": + x = True + elif not isinstance(x, str): + raise TypeError( + "input should be an int/float/boolean/str, had {}".format(type(x)) + ) + return x + + +class LoopInfo: + """ + Class to print the current iteration in a for loop, and optionally the estimated remaining time. + Instantiate just before the loop, and call the update method at the start of the loop. + The printed text has the following format: + processing i/total remaining time: hh:mm:ss + """ + + def __init__(self, n_iterations, spacing=10, text="processing", print_time=False): + """ + :param n_iterations: total number of iterations of the for loop. + :param spacing: frequency at which the update info will be printed on screen. + :param text: text to print. Default is processing. + :param print_time: whether to print the estimated remaining time. Default is False. + """ + + # loop parameters + self.n_iterations = n_iterations + self.spacing = spacing + + # text parameters + self.text = text + self.print_time = print_time + self.print_previous_time = False + self.align = len(str(self.n_iterations)) * 2 + 1 + 3 + + # timing parameters + self.iteration_durations = np.zeros((n_iterations,)) + self.start = time.time() + self.previous = time.time() + + def update(self, idx): + + # time iteration + now = time.time() + self.iteration_durations[idx] = now - self.previous + self.previous = now + + # print text + if idx == 0: + print(self.text + " 1/{}".format(self.n_iterations)) + elif idx % self.spacing == self.spacing - 1: + iteration = str(idx + 1) + "/" + str(self.n_iterations) + if self.print_time: + # estimate remaining time + max_duration = np.max(self.iteration_durations) + average_duration = np.mean( + self.iteration_durations[ + self.iteration_durations > 0.01 * max_duration + ] + ) + remaining_time = int(average_duration * (self.n_iterations - idx)) + # print total remaining time only if it is greater than 1s or if it was previously printed + if (remaining_time > 1) | self.print_previous_time: + eta = str(timedelta(seconds=remaining_time)) + print( + self.text + + " {:<{x}} remaining time: {}".format( + iteration, eta, x=self.align + ) + ) + self.print_previous_time = True + else: + print(self.text + " {}".format(iteration)) + else: + print(self.text + " {}".format(iteration)) + + +def get_mapping_lut(source, dest=None): + """This functions returns the look-up table to map a list of N values (source) to another list (dest). + If the second list is not given, we assume it is equal to [0, ..., N-1].""" + + # initialise + source = np.array(reformat_to_list(source), dtype="int32") + n_labels = source.shape[0] + + # build new label list if necessary + if dest is None: + dest = np.arange(n_labels, dtype="int32") + else: + assert len(source) == len( + dest + ), "label_list and new_label_list should have the same length" + dest = np.array(reformat_to_list(dest, dtype="int")) + + # build look-up table + lut = np.zeros(np.max(source) + 1, dtype="int32") + for source, dest in zip(source, dest): + lut[source] = dest + + return lut + + +def build_training_generator(gen, batchsize): + """Build generator for training a network.""" + while True: + inputs = next(gen) + if batchsize > 1: + target = np.concatenate([np.zeros((1, 1))] * batchsize, 0) + else: + target = np.zeros((1, 1)) + yield inputs, target + + +def find_closest_number_divisible_by_m(n, m, answer_type="lower"): + """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns + values lower than n), or 'higher' (only returns values higher than m).""" + if n % m == 0: + return n + else: + q = int(n / m) + lower = q * m + higher = (q + 1) * m + if answer_type == "lower": + return lower + elif answer_type == "higher": + return higher + elif answer_type == "closer": + return lower if (n - lower) < (higher - n) else higher + else: + raise Exception( + "answer_type should be lower, higher, or closer, had : %s" % answer_type + ) + + +def build_binary_structure(connectivity, n_dims, shape=None): + """Return a dilation/erosion element with provided connectivity""" + if shape is None: + shape = [connectivity * 2 + 1] * n_dims + else: + shape = reformat_to_list(shape, length=n_dims) + dist = np.ones(shape) + center = tuple([tuple([int(s / 2)]) for s in shape]) + dist[center] = 0 + dist = distance_transform_edt(dist) + struct = (dist <= connectivity) * 1 + return struct + + +def draw_value_from_distribution( + hyperparameter, + size=1, + distribution="uniform", + centre=0.0, + default_range=10.0, + positive_only=False, + return_as_tensor=False, + batchsize=None, +): + """Sample values from a uniform, or normal distribution of given hyperparameters. + These hyperparameters are to the number of 2 in both uniform and normal cases. + :param hyperparameter: values of the hyperparameters. Can either be: + 1) None, in each case the two hyperparameters are given by [center-default_range, center+default_range], + 2) a number, where the two hyperparameters are given by [centre-hyperparameter, centre+hyperparameter], + 3) a sequence of length 2, directly defining the two hyperparameters: [min, max] if the distribution is uniform, + [mean, std] if the distribution is normal. + 4) a numpy array, with size (2, m). In this case, the function returns a 1d array of size m, where each value has + been sampled independently with the specified hyperparameters. If the distribution is uniform, rows correspond to + its lower and upper bounds, and if the distribution is normal, rows correspond to its mean and std deviation. + 5) a numpy array of size (2*n, m). Same as 4) but we first randomly select a block of two rows among the + n possibilities. + 6) the path to a numpy array corresponding to case 4 or 5. + 7) False, in which case this function returns None. + :param size: (optional) number of values to sample. All values are sampled independently. + Used only if hyperparameter is not a numpy array. + :param distribution: (optional) the distribution type. Can be 'uniform' or 'normal'. Default is 'uniform'. + :param centre: (optional) default centre to use if hyperparameter is None or a number. + :param default_range: (optional) default range to use if hyperparameter is None. + :param positive_only: (optional) whether to reset all negative values to zero. + :param return_as_tensor: (optional) whether to return the result as a tensorflow tensor + :param batchsize: (optional) if return_as_tensor is true, then you can sample a tensor of a given batchsize. Give + this batchsize as a tensorflow tensor here. + :return: a float, or a numpy 1d array if size > 1, or hyperparameter is itself a numpy array. + Returns None if hyperparameter is False. + """ + + # return False is hyperparameter is False + if hyperparameter is False: + return None + + # reformat parameter_range + hyperparameter = load_array_if_path(hyperparameter, load_as_numpy=True) + if not isinstance(hyperparameter, np.ndarray): + if hyperparameter is None: + hyperparameter = np.array( + [[centre - default_range] * size, [centre + default_range] * size] + ) + elif isinstance(hyperparameter, (int, float)): + hyperparameter = np.array( + [[centre - hyperparameter] * size, [centre + hyperparameter] * size] + ) + elif isinstance(hyperparameter, (list, tuple)): + assert ( + len(hyperparameter) == 2 + ), "if list, parameter_range should be of length 2." + hyperparameter = np.transpose(np.tile(np.array(hyperparameter), (size, 1))) + else: + raise ValueError( + "parameter_range should either be None, a number, a sequence, or a numpy array." + ) + elif isinstance(hyperparameter, np.ndarray): + assert ( + hyperparameter.shape[0] % 2 == 0 + ), "number of rows of parameter_range should be divisible by 2" + n_modalities = int(hyperparameter.shape[0] / 2) + modality_idx = 2 * np.random.randint(n_modalities) + hyperparameter = hyperparameter[modality_idx : modality_idx + 2, :] + + # draw values as tensor + if return_as_tensor: + shape = KL.Lambda( + lambda x: tf.convert_to_tensor(hyperparameter.shape[1], "int32") + )([]) + if batchsize is not None: + shape = KL.Lambda( + lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0) + )([batchsize, shape]) + if distribution == "uniform": + parameter_value = KL.Lambda( + lambda x: tf.random.uniform( + shape=x, minval=hyperparameter[0, :], maxval=hyperparameter[1, :] + ) + )(shape) + elif distribution == "normal": + parameter_value = KL.Lambda( + lambda x: tf.random.normal( + shape=x, mean=hyperparameter[0, :], stddev=hyperparameter[1, :] + ) + )(shape) + else: + raise ValueError( + "Distribution not supported, should be 'uniform' or 'normal'." + ) + + if positive_only: + parameter_value = KL.Lambda(lambda x: K.clip(x, 0, None))(parameter_value) + + # draw values as numpy array + else: + if distribution == "uniform": + parameter_value = np.random.uniform( + low=hyperparameter[0, :], high=hyperparameter[1, :] + ) + elif distribution == "normal": + parameter_value = np.random.normal( + loc=hyperparameter[0, :], scale=hyperparameter[1, :] + ) + else: + raise ValueError( + "Distribution not supported, should be 'uniform' or 'normal'." + ) + + if positive_only: + parameter_value[parameter_value < 0] = 0 + + return parameter_value + + +def build_exp(x, first, last, fix_point): + # first = f(0), last = f(+inf), fix_point = [x0, f(x0))] + a = last + b = first - last + c = -(1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last)) + return a + b * np.exp(-c * x) diff --git a/nobrainer/ext/neuron/__init__.py b/nobrainer/ext/neuron/__init__.py new file mode 100644 index 00000000..6a418b95 --- /dev/null +++ b/nobrainer/ext/neuron/__init__.py @@ -0,0 +1 @@ +from . import layers, models, utils diff --git a/nobrainer/ext/neuron/layers.py b/nobrainer/ext/neuron/layers.py new file mode 100644 index 00000000..9f4e1821 --- /dev/null +++ b/nobrainer/ext/neuron/layers.py @@ -0,0 +1,506 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +or for the transformation/integration functions: + +Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration +Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu +MICCAI 2018. + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +from copy import deepcopy + +from keras import backend as K +from keras.layers import Layer + +# third party +import tensorflow as tf + +# local +from nobrainer.ext.neuron.utils import ( + affine_to_shift, + combine_non_linear_and_aff_to_shift, + integrate_vec, + resize, + transform, +) + + +class SpatialTransformer(Layer): + """ + N-D Spatial Transformer Tensorflow / Keras Layer + + The Layer can handle both affine and dense transforms. + Both transforms are meant to give a 'shift' from the current position. + Therefore, a dense transform gives displacements (not absolute locations) at each voxel, + and an affine transform gives the *difference* of the affine matrix from + the identity matrix. + + If you find this function useful, please cite: + Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration + Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu + MICCAI 2018. + + Originally, this code was based on voxelmorph code, which + was in turn transformed to be dense with the help of (affine) STN code + via https://github.com/kevinzakka/spatial-transformer-network + + Since then, we've re-written the code to be generalized to any + dimensions, and along the way wrote grid and interpolation functions + """ + + def __init__( + self, interp_method="linear", indexing="ij", single_transform=False, **kwargs + ): + """ + Parameters: + interp_method: 'linear' or 'nearest' + single_transform: whether a single transform supplied for the whole batch + indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian) + 'xy' indexing will have the first two entries of the flow + (along last axis) flipped compared to 'ij' indexing + """ + self.interp_method = interp_method + self.ndims = None + self.inshape = None + self.single_transform = single_transform + self.is_affine = list() + + assert indexing in [ + "ij", + "xy", + ], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + self.indexing = indexing + + super(self.__class__, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["interp_method"] = self.interp_method + config["indexing"] = self.indexing + config["single_transform"] = self.single_transform + return config + + def build(self, input_shape): + """ + input_shape should be a list for two inputs: + input1: image. + input2: list of transform Tensors + if affine: + should be an N+1 x N+1 matrix + *or* a N+1*N+1 tensor (which will be reshaped to N x (N+1) and an identity row added) + if not affine: + should be a *vol_shape x N + """ + + if len(input_shape) > 3: + raise Exception( + "Spatial Transformer must be called on a list of min length 2 and max length 3." + "First argument is the image followed by the affine and non linear transforms." + ) + + # set up number of dimensions + self.ndims = len(input_shape[0]) - 2 + self.inshape = input_shape + trf_shape = [trans_shape[1:] for trans_shape in input_shape[1:]] + + for i, shape in enumerate(trf_shape): + + # the transform is an affine iff: + # it's a 1D Tensor [dense transforms need to be at least ndims + 1] + # it's a 2D Tensor and shape == [N+1, N+1]. + self.is_affine.append( + len(shape) == 1 + or (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape])) + ) + + # check sizes + if self.is_affine[i] and len(shape) == 1: + ex = self.ndims * (self.ndims + 1) + if shape[0] != ex: + raise Exception( + "Expected flattened affine of len %d but got %d" + % (ex, shape[0]) + ) + + if not self.is_affine[i]: + if shape[-1] != self.ndims: + raise Exception( + "Offset flow field size expected: %d, found: %d" + % (self.ndims, shape[-1]) + ) + + # confirm built + self.built = True + + def call(self, inputs, **kwargs): + """ + Parameters + inputs: list with several entries: the volume followed by the transforms + """ + + # check shapes + assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len( + inputs + ) + vol = inputs[0] + trf = inputs[1:] + + # necessary for multi_gpu models... + vol = K.reshape(vol, [-1, *self.inshape[0][1:]]) + for i in range(len(trf)): + trf[i] = K.reshape(trf[i], [-1, *self.inshape[i + 1][1:]]) + + # reorder transforms, non-linear first and affine second + ind_nonlinear_linear = [ + i[0] for i in sorted(enumerate(self.is_affine), key=lambda x: x[1]) + ] + self.is_affine = [self.is_affine[i] for i in ind_nonlinear_linear] + self.inshape = [self.inshape[i] for i in ind_nonlinear_linear] + trf = [trf[i] for i in ind_nonlinear_linear] + + # go from affine to deformation field + if len(trf) == 1: + trf = trf[0] + if self.is_affine[0]: + trf = tf.map_fn( + lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), + trf, + dtype=tf.float32, + ) + # combine non-linear and affine to obtain a single deformation field + elif len(trf) == 2: + trf = tf.map_fn( + lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), + trf, + dtype=tf.float32, + ) + + # prepare location shift + if self.indexing == "xy": # shift the first two dimensions + trf_split = tf.split(trf, trf.shape[-1], axis=-1) + trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]] + trf = tf.concat(trf_lst, -1) + + # map transform across batch + if self.single_transform: + return tf.map_fn(self._single_transform, [vol, trf[0, :]], dtype=tf.float32) + else: + return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32) + + def _single_aff_to_shift(self, trf, volshape): + if len(trf.shape) == 1: # go from vector to matrix + trf = tf.reshape(trf, [self.ndims, self.ndims + 1]) + return affine_to_shift(trf, volshape, shift_center=True) + + def _non_linear_and_aff_to_shift(self, trf, volshape): + if len(trf[1].shape) == 1: # go from vector to matrix + trf[1] = tf.reshape(trf[1], [self.ndims, self.ndims + 1]) + return combine_non_linear_and_aff_to_shift(trf, volshape, shift_center=True) + + def _single_transform(self, inputs): + return transform(inputs[0], inputs[1], interp_method=self.interp_method) + + +class VecInt(Layer): + """ + Vector Integration Layer + + Enables vector integration via several methods + (ode or quadrature for time-dependent vector fields, + scaling and squaring for stationary fields) + + If you find this function useful, please cite: + Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration + Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu + MICCAI 2018. + """ + + def __init__( + self, + indexing="ij", + method="ss", + int_steps=7, + out_time_pt=1, + ode_args=None, + odeint_fn=None, + **kwargs + ): + """ + Parameters: + method can be any of the methods in neuron.utils.integrate_vec + indexing can be 'xy' (switches first two dimensions) or 'ij' + int_steps is the number of integration steps + out_time_pt is time point at which to output if using odeint integration + """ + + assert indexing in [ + "ij", + "xy", + ], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + self.indexing = indexing + self.method = method + self.int_steps = int_steps + self.inshape = None + self.out_time_pt = out_time_pt + self.odeint_fn = odeint_fn # if none then will use a tensorflow function + self.ode_args = ode_args + if ode_args is None: + self.ode_args = {"rtol": 1e-6, "atol": 1e-12} + super(self.__class__, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["indexing"] = self.indexing + config["method"] = self.method + config["int_steps"] = self.int_steps + config["out_time_pt"] = self.out_time_pt + config["ode_args"] = self.ode_args + config["odeint_fn"] = self.odeint_fn + return config + + def build(self, input_shape): + # confirm built + self.built = True + + trf_shape = input_shape + if isinstance(input_shape[0], (list, tuple)): + trf_shape = input_shape[0] + self.inshape = trf_shape + + if trf_shape[-1] != len(trf_shape) - 2: + raise Exception( + "transform ndims %d does not match expected ndims %d" + % (trf_shape[-1], len(trf_shape) - 2) + ) + + def call(self, inputs, **kwargs): + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + loc_shift = inputs[0] + + # necessary for multi_gpu models... + loc_shift = K.reshape(loc_shift, [-1, *self.inshape[1:]]) + + # prepare location shift + if self.indexing == "xy": # shift the first two dimensions + loc_shift_split = tf.split(loc_shift, loc_shift.shape[-1], axis=-1) + loc_shift_lst = [ + loc_shift_split[1], + loc_shift_split[0], + *loc_shift_split[2:], + ] + loc_shift = tf.concat(loc_shift_lst, -1) + + if len(inputs) > 1: + assert ( + self.out_time_pt is None + ), "out_time_pt should be None if providing batch_based out_time_pt" + + # map transform across batch + out = tf.map_fn(self._single_int, [loc_shift] + inputs[1:], dtype=tf.float32) + return out + + def _single_int(self, inputs): + + vel = inputs[0] + out_time_pt = self.out_time_pt + if len(inputs) == 2: + out_time_pt = inputs[1] + return integrate_vec( + vel, + method=self.method, + nb_steps=self.int_steps, + ode_args=self.ode_args, + out_time_pt=out_time_pt, + odeint_fn=self.odeint_fn, + ) + + +class Resize(Layer): + """ + N-D Resize Tensorflow / Keras Layer + Note: this is not re-shaping an existing volume, but resizing, like scipy's "Zoom" + + If you find this function useful, please cite: + Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,Dalca AV, Guttag J, Sabuncu MR + CVPR 2018 + + Since then, we've re-written the code to be generalized to any + dimensions, and along the way wrote grid and interpolation functions + """ + + def __init__(self, zoom_factor=None, size=None, interp_method="linear", **kwargs): + """ + Parameters: + interp_method: 'linear' or 'nearest' + 'xy' indexing will have the first two entries of the flow + (along last axis) flipped compared to 'ij' indexing + """ + self.zoom_factor = zoom_factor + self.size = list(size) + self.zoom_factor0 = None + self.size0 = None + self.interp_method = interp_method + self.ndims = None + self.inshape = None + super(Resize, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["zoom_factor"] = self.zoom_factor + config["size"] = self.size + config["interp_method"] = self.interp_method + return config + + def build(self, input_shape): + """ + input_shape should be an element of list of one inputs: + input1: volume + should be a *vol_shape x N + """ + + if isinstance(input_shape[0], (list, tuple)) and len(input_shape) > 1: + raise Exception("Resize must be called on a list of length 1.") + + if isinstance(input_shape[0], (list, tuple)): + input_shape = input_shape[0] + + # set up number of dimensions + self.ndims = len(input_shape) - 2 + self.inshape = input_shape + + # check zoom_factor + if isinstance(self.zoom_factor, float): + self.zoom_factor0 = [self.zoom_factor] * self.ndims + elif self.zoom_factor is None: + self.zoom_factor0 = [0] * self.ndims + elif isinstance(self.zoom_factor, (list, tuple)): + self.zoom_factor0 = deepcopy(self.zoom_factor) + assert ( + len(self.zoom_factor0) == self.ndims + ), "zoom factor length {} does not match number of dimensions {}".format( + len(self.zoom_factor), self.ndims + ) + else: + raise Exception( + "zoom_factor should be an int or a list/tuple of int (or None if size is not set to None)" + ) + + # check size + if isinstance(self.size, int): + self.size0 = [self.size] * self.ndims + elif self.size is None: + self.size0 = [0] * self.ndims + elif isinstance(self.size, (list, tuple)): + self.size0 = deepcopy(self.size) + assert ( + len(self.size0) == self.ndims + ), "size length {} does not match number of dimensions {}".format( + len(self.size0), self.ndims + ) + else: + raise Exception( + "size should be an int or a list/tuple of int (or None if zoom_factor is not set to None)" + ) + + # confirm built + self.built = True + + super(Resize, self).build(input_shape) # Be sure to call this somewhere! + + def call(self, inputs, **kwargs): + """ + Parameters + inputs: volume or list of one volume + """ + + # check shapes + if isinstance(inputs, (list, tuple)): + assert len(inputs) == 1, "inputs has to be len 1. found: %d" % len(inputs) + vol = inputs[0] + else: + vol = inputs + + # necessary for multi_gpu models... + vol = K.reshape(vol, [-1, *self.inshape[1:]]) + + # set value of missing size or zoom_factor + if not any(self.zoom_factor0): + self.zoom_factor0 = [ + self.size0[i] / self.inshape[i + 1] for i in range(self.ndims) + ] + else: + self.size0 = [ + int(self.inshape[f + 1] * self.zoom_factor0[f]) + for f in range(self.ndims) + ] + + # map transform across batch + return tf.map_fn(self._single_resize, vol, dtype=vol.dtype) + + def compute_output_shape(self, input_shape): + + output_shape = [input_shape[0]] + output_shape += [ + int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims) + ] + output_shape += [input_shape[-1]] + return tuple(output_shape) + + def _single_resize(self, inputs): + return resize( + inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method + ) + + +# Zoom naming of resize, to match scipy's naming +Zoom = Resize + + +######################################################### +# "Local" layers -- layers with parameters at each voxel +######################################################### + + +class LocalBias(Layer): + """ + Local bias layer: each pixel/voxel has its own bias operation (one parameter) + out[v] = in[v] + b + """ + + def __init__(self, my_initializer="RandomNormal", biasmult=1.0, **kwargs): + self.initializer = my_initializer + self.biasmult = biasmult + self.kernel = None + super(LocalBias, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["my_initializer"] = self.initializer + config["biasmult"] = self.biasmult + return config + + def build(self, input_shape): + # Create a trainable weight variable for this layer. + self.kernel = self.add_weight( + name="kernel", + shape=input_shape[1:], + initializer=self.initializer, + trainable=True, + ) + super(LocalBias, self).build(input_shape) # Be sure to call this somewhere! + + def call(self, x, **kwargs): + return x + self.kernel * self.biasmult # weights are difference from input + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/nobrainer/ext/neuron/models.py b/nobrainer/ext/neuron/models.py new file mode 100644 index 00000000..a13a9167 --- /dev/null +++ b/nobrainer/ext/neuron/models.py @@ -0,0 +1,875 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import sys + +import keras +import keras.backend as K +import keras.layers as KL +from keras.models import Model + +# third party +import numpy as np +import tensorflow as tf + +from nobrainer.ext.neuron import layers + + +def unet( + nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name="unet", + prefix=None, + feat_mult=1, + pool_size=2, + use_logp=True, + padding="same", + dilation_rate_mult=1, + activation="elu", + skip_n_concatenations=0, + use_residuals=False, + final_pred_activation="softmax", + nb_conv_per_level=1, + add_prior_layer=False, + layer_nb_feats=None, + conv_dropout=0, + batch_norm=None, + input_model=None, +): + """ + unet-style keras model with an overdose of parametrization. + + Parameters: + nb_features: the number of features at each convolutional level + see below for `feat_mult` and `layer_nb_feats` for modifiers to this number + input_shape: input layer shape, vector of size ndims + 1 (nb_channels) + conv_size: the convolution kernel size + nb_levels: the number of Unet levels (number of downsamples) in the "encoder" + (e.g. 4 would give you 4 levels in encoder, 4 in decoder) + nb_labels: number of output channels + name (default: 'unet'): the name of the network + prefix (default: `name` value): prefix to be added to layer names + feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels. + e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the + second layer, 64 features in the third layer, etc. + pool_size (default: 2): max pooling size (integer or list if specifying per dimension) + skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n + top levels. + use_logp: + padding: + dilation_rate_mult: + activation: + use_residuals: + final_pred_activation: + nb_conv_per_level: + add_prior_layer: + skip_n_concatenations: + layer_nb_feats: list of the number of features for each layer. Automatically used if specified + conv_dropout: dropout probability + batch_norm: + input_model: concatenate the provided input_model to this current model. + Only the first output of input_model is used. + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc( + nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + layer_nb_feats=layer_nb_feats, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model, + ) + + # get decoder + # use_skip_connections=True makes it a u-net + lnf = ( + layer_nb_feats[(nb_levels * nb_conv_per_level) :] + if layer_nb_feats is not None + else None + ) + dec_model = conv_dec( + nb_features, + [], + nb_levels, + conv_size, + nb_labels, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=True, + skip_n_concatenations=skip_n_concatenations, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + final_pred_activation="linear" if add_prior_layer else final_pred_activation, + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + layer_nb_feats=lnf, + conv_dropout=conv_dropout, + input_model=enc_model, + ) + final_model = dec_model + + if add_prior_layer: + final_model = add_prior( + dec_model, + [*input_shape[:-1], nb_labels], + name=model_name + "_prior", + use_logp=use_logp, + final_pred_activation=final_pred_activation, + ) + + return final_model + + +def ae( + nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + enc_size, + name="ae", + feat_mult=1, + pool_size=2, + padding="same", + activation="elu", + use_residuals=False, + nb_conv_per_level=1, + batch_norm=None, + enc_batch_norm=None, + ae_type="conv", # 'dense', or 'conv' + enc_lambda_layers=None, + add_prior_layer=False, + use_logp=True, + conv_dropout=0, + include_mu_shift_layer=False, + single_model=False, # whether to return a single model, or a tuple of models that can be stacked. + final_pred_activation="softmax", + do_vae=False, + input_model=None, +): + """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).""" + + # naming + model_name = name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc( + nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model, + ) + + # middle AE structure + if single_model: + in_input_shape = None + in_model = enc_model + else: + in_input_shape = enc_model.output.shape.as_list()[1:] + in_model = None + mid_ae_model = single_ae( + enc_size, + in_input_shape, + conv_size=conv_size, + name=model_name, + ae_type=ae_type, + input_model=in_model, + batch_norm=enc_batch_norm, + enc_lambda_layers=enc_lambda_layers, + include_mu_shift_layer=include_mu_shift_layer, + do_vae=do_vae, + ) + + # decoder + if single_model: + in_input_shape = None + in_model = mid_ae_model + else: + in_input_shape = mid_ae_model.output.shape.as_list()[1:] + in_model = None + dec_model = conv_dec( + nb_features, + in_input_shape, + nb_levels, + conv_size, + nb_labels, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=False, + padding=padding, + activation=activation, + use_residuals=use_residuals, + final_pred_activation="linear", + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + conv_dropout=conv_dropout, + input_model=in_model, + ) + + if add_prior_layer: + dec_model = add_prior( + dec_model, + [*input_shape[:-1], nb_labels], + name=model_name, + prefix=model_name + "_prior", + use_logp=use_logp, + final_pred_activation=final_pred_activation, + ) + + if single_model: + return dec_model + else: + return dec_model, mid_ae_model, enc_model + + +def conv_enc( + nb_features, + input_shape, + nb_levels, + conv_size, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + dilation_rate_mult=1, + padding="same", + activation="elu", + layer_nb_feats=None, + use_residuals=False, + nb_conv_per_level=2, + conv_dropout=0, + batch_norm=None, + input_model=None, +): + """Fully Convolutional Encoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # first layer: input + name = "%s_input" % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=name) + last_tensor = input_tensor + else: + input_tensor = input_model.inputs + last_tensor = input_model.outputs + if isinstance(last_tensor, list): + last_tensor = last_tensor[0] + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, "Conv%dD" % ndims) + conv_kwargs = { + "padding": padding, + "activation": activation, + "data_format": "channels_last", + } + maxpool = getattr(KL, "MaxPooling%dD" % ndims) + + # down arm: + # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers + lfidx = 0 # level feature index + for level in range(nb_levels): + lvl_first_tensor = last_tensor + nb_lvl_feats = np.round(nb_features * feat_mult**level).astype(int) + conv_kwargs["dilation_rate"] = dilation_rate_mult**level + + for conv in range( + nb_conv_per_level + ): # does several conv per level, max pooling applied at the end + if layer_nb_feats is not None: # None or List of all the feature numbers + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = "%s_conv_downarm_%d_%d" % (prefix, level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + last_tensor + ) + else: # no activation + last_tensor = convL( + nb_lvl_feats, conv_size, padding=padding, name=name + )(last_tensor) + + if conv_dropout > 0: + # conv dropout along feature space only + name = "%s_dropout_downarm_%d_%d" % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout( + conv_dropout, noise_shape=noise_shape, name=name + )(last_tensor) + + if use_residuals: + convarm_layer = last_tensor + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + nb_feats_in = lvl_first_tensor.get_shape()[-1] + nb_feats_out = convarm_layer.get_shape()[-1] + add_layer = lvl_first_tensor + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = "%s_expand_down_merge_%d" % (prefix, level) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + lvl_first_tensor + ) + add_layer = last_tensor + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)( + last_tensor + ) + + name = "%s_res_down_merge_%d" % (prefix, level) + last_tensor = KL.add([add_layer, convarm_layer], name=name) + + name = "%s_res_down_merge_act_%d" % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = "%s_bn_down_%d" % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # max pool if we're not at the last level + if level < (nb_levels - 1): + name = "%s_maxpool_%d" % (prefix, level) + last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)( + last_tensor + ) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +def conv_dec( + nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + use_skip_connections=False, + skip_n_concatenations=0, + padding="same", + dilation_rate_mult=1, + activation="elu", + use_residuals=False, + final_pred_activation="softmax", + nb_conv_per_level=2, + layer_nb_feats=None, + batch_norm=None, + conv_dropout=0, + input_model=None, +): + """Fully Convolutional Decoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # if using skip connections, make sure need to use them. + if use_skip_connections: + assert ( + input_model is not None + ), "is using skip connections, tensors dictionary is required" + + # first layer: input + input_name = "%s_input" % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + + # vol size info + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + if ndims > 1: + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, "Conv%dD" % ndims) + conv_kwargs = {"padding": padding, "activation": activation} + upsample = getattr(KL, "UpSampling%dD" % ndims) + + # up arm: + # nb_levels - 1 layers of Deconvolution3D + # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu + lfidx = 0 + for level in range(nb_levels - 1): + nb_lvl_feats = np.round( + nb_features * feat_mult ** (nb_levels - 2 - level) + ).astype(int) + conv_kwargs["dilation_rate"] = dilation_rate_mult ** (nb_levels - 2 - level) + + # upsample matching the max pooling layers size + name = "%s_up_%d" % (prefix, nb_levels + level) + last_tensor = upsample(size=pool_size, name=name)(last_tensor) + up_tensor = last_tensor + + # merge layers combining previous layer + if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)): + conv_name = "%s_conv_downarm_%d_%d" % ( + prefix, + nb_levels - 2 - level, + nb_conv_per_level - 1, + ) + cat_tensor = input_model.get_layer(conv_name).output + name = "%s_merge_%d" % (prefix, nb_levels + level) + last_tensor = KL.concatenate( + [cat_tensor, last_tensor], axis=ndims + 1, name=name + ) + + # convolution layers + for conv in range(nb_conv_per_level): + if layer_nb_feats is not None: + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = "%s_conv_uparm_%d_%d" % (prefix, nb_levels + level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + last_tensor + ) + else: + last_tensor = convL( + nb_lvl_feats, conv_size, padding=padding, name=name + )(last_tensor) + + if conv_dropout > 0: + name = "%s_dropout_uparm_%d_%d" % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout( + conv_dropout, noise_shape=noise_shape, name=name + )(last_tensor) + + # residual block + if use_residuals: + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + add_layer = up_tensor + nb_feats_in = add_layer.get_shape()[-1] + nb_feats_out = last_tensor.get_shape()[-1] + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = "%s_expand_up_merge_%d" % (prefix, level) + add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + add_layer + ) + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)( + last_tensor + ) + + name = "%s_res_up_merge_%d" % (prefix, level) + last_tensor = KL.add([last_tensor, add_layer], name=name) + + name = "%s_res_up_merge_act_%d" % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = "%s_bn_up_%d" % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # Compute likelihood prediction (no activation yet) + name = "%s_likelihood" % prefix + last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) + like_tensor = last_tensor + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + if final_pred_activation == "softmax": + name = "%s_prediction" % prefix + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor) + + # otherwise create a layer that does nothing. + else: + name = "%s_prediction" % prefix + pred_tensor = KL.Activation("linear", name=name)(like_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) + return model + + +def add_prior( + input_model, + prior_shape, + name="prior_model", + prefix=None, + use_logp=True, + final_pred_activation="softmax", +): + """ + Append post-prior layer to a given model + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # prior input layer + prior_input_name = "%s-input" % prefix + prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name) + prior_tensor_input = prior_tensor + like_tensor = input_model.output + + # operation varies depending on whether we log() prior or not. + if use_logp: + print( + "Breaking change: use_logp option now requires log input!", file=sys.stderr + ) + merge_op = KL.add + + else: + # using sigmoid to get the likelihood values between 0 and 1 + # note: they won't add up to 1. + name = "%s_likelihood_sigmoid" % prefix + like_tensor = KL.Activation("sigmoid", name=name)(like_tensor) + merge_op = KL.multiply + + # merge the likelihood and prior layers into posterior layer + name = "%s_posterior" % prefix + post_tensor = merge_op([prior_tensor, like_tensor], name=name) + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + pred_name = "%s_prediction" % prefix + if final_pred_activation == "softmax": + assert use_logp, "cannot do softmax when adding prior via P()" + print( + "using final_pred_activation %s for %s" + % (final_pred_activation, model_name) + ) + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor) + + else: + pred_tensor = KL.Activation("linear", name=pred_name)(post_tensor) + + # create the model + model_inputs = [*input_model.inputs, prior_tensor_input] + model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name) + + # compile + return model + + +def single_ae( + enc_size, + input_shape, + name="single_ae", + prefix=None, + ae_type="dense", # 'dense', or 'conv' + conv_size=None, + input_model=None, + enc_lambda_layers=None, + batch_norm=True, + padding="same", + activation=None, + include_mu_shift_layer=False, + do_vae=False, +): + """single-layer Autoencoder (i.e. input - encoding - output""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + if enc_lambda_layers is None: + enc_lambda_layers = [] + + # prepare input + input_name = "%s_input" % prefix + if input_model is None: + assert input_shape is not None, "input_shape of input_model is necessary" + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + input_nb_feats = last_tensor.shape.as_list()[-1] + + # prepare conv type based on input + ndims = len(input_shape) - 1 + if ae_type == "conv": + convL = getattr(KL, "Conv%dD" % ndims) + assert conv_size is not None, "with conv ae, need conv_size" + conv_kwargs = {"padding": padding, "activation": activation} + enc_size_str = None + + # if want to go through a dense layer in the middle of the U, need to: + # - flatten last layer if not flat + # - do dense encoding and decoding + # - unflatten (reshape spatially) at end + else: # ae_type == 'dense' + if len(input_shape) > 1: + name = "%s_ae_%s_down_flat" % (prefix, ae_type) + last_tensor = KL.Flatten(name=name)(last_tensor) + convL = conv_kwargs = None + assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer" + enc_size_str = "".join(["%d_" % d for d in enc_size])[:-1] + + # recall this layer + pre_enc_layer = last_tensor + + # encoding layer + if ae_type == "dense": + name = "%s_ae_mu_enc_dense_%s" % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: # convolution + + # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats] + assert len(enc_size) == len( + input_shape + ), "encoding size does not match input shape %d %d" % ( + len(enc_size), + len(input_shape), + ) + + if ( + list(enc_size)[:-1] != list(input_shape)[:-1] + and all([f is not None for f in input_shape[:-1]]) + and all([f is not None for f in enc_size[:-1]]) + ): + + name = "%s_ae_mu_enc_conv" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) + + name = "%s_ae_mu_enc" % prefix + zf = [ + enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] + for f in range(len(enc_size) - 1) + ] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = "%s_ae_mu_enc" % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = "%s_ae_mu_enc" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) + + if include_mu_shift_layer: + # shift + name = "%s_ae_mu_shift" % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = "%s_ae_mu_%s" % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = "%s_ae_mu_bn" % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = "%s_ae_mu" % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + # if doing variational AE, will need the sigma layer as well. + if do_vae: + mu_tensor = last_tensor + + # encoding layer + if ae_type == "dense": + name = "%s_ae_sigma_enc_dense_%s" % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: + if ( + list(enc_size)[:-1] != list(input_shape)[:-1] + and all([f is not None for f in input_shape[:-1]]) + and all([f is not None for f in enc_size[:-1]]) + ): + + assert ( + len(enc_size) - 1 == 2 + ), "Sorry, I have not yet implemented non-2D resizing..." + name = "%s_ae_sigma_enc_conv" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) + + name = "%s_ae_sigma_enc" % prefix + resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) + last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = "%s_ae_sigma_enc" % prefix + last_tensor = convL( + pre_enc_layer.shape.as_list()[-1], + conv_size, + name=name, + **conv_kwargs + )(pre_enc_layer) + # cannot use lambda, then mu and sigma will be same layer. + # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = "%s_ae_sigma_enc" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = "%s_ae_sigma_%s" % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = "%s_ae_sigma_bn" % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = "%s_ae_sigma" % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + logvar_tensor = last_tensor + + # VAE sampling + sampler = _VAESample().sample_z + + name = "%s_ae_sample" % prefix + last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor]) + + if include_mu_shift_layer: + # shift + name = "%s_ae_sample_shift" % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # decoding layer + if ae_type == "dense": + name = "%s_ae_%s_dec_flat_%s" % (prefix, ae_type, enc_size_str) + last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor) + + # unflatten if dense method + if len(input_shape) > 1: + name = "%s_ae_%s_dec" % (prefix, ae_type) + last_tensor = KL.Reshape(input_shape, name=name)(last_tensor) + + else: + + if ( + list(enc_size)[:-1] != list(input_shape)[:-1] + and all([f is not None for f in input_shape[:-1]]) + and all([f is not None for f in enc_size[:-1]]) + ): + name = "%s_ae_mu_dec" % prefix + zf = [ + last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] + for f in range(len(enc_size) - 1) + ] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + name = "%s_ae_%s_dec" % (prefix, ae_type) + last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)( + last_tensor + ) + + if batch_norm is not None: + name = "%s_bn_ae_%s_dec" % (prefix, ae_type) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +############################################################################### +# Helper function +############################################################################### + + +class _VAESample: + def __init__(self): + pass + + def sample_z(self, args): + mu, log_var = args + shape = K.shape(mu) + eps = K.random_normal(shape=shape, mean=0.0, stddev=1.0) + return mu + K.exp(log_var / 2) * eps diff --git a/nobrainer/ext/neuron/utils.py b/nobrainer/ext/neuron/utils.py new file mode 100644 index 00000000..c6b94028 --- /dev/null +++ b/nobrainer/ext/neuron/utils.py @@ -0,0 +1,593 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +or for the transformation/interpolation related functions: + +Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration +Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu +MICCAI 2018. + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import itertools + +import keras.backend as K +import numpy as np +import tensorflow as tf + + +def interpn(vol, loc, interp_method="linear"): + """ + N-D gridded interpolation in tensorflow + + vol can have more dimensions than loc[i], in which case loc[i] acts as a slice + for the first dimensions + + Parameters: + vol: volume with size vol_shape or [*vol_shape, nb_features] + loc: an N-long list of N-D Tensors (the interpolation locations) for the new grid + each tensor has to have the same size (but not nec. same size as vol) + or a tensor of size [*new_vol_shape, D] + interp_method: interpolation type 'linear' (default) or 'nearest' + + Returns: + new interpolated volume of the same size as the entries in loc + """ + + if isinstance(loc, (list, tuple)): + loc = tf.stack(loc, -1) + nb_dims = loc.shape[-1] + + if len(vol.shape) not in [nb_dims, nb_dims + 1]: + raise Exception( + "Number of loc Tensors %d does not match volume dimension %d" + % (nb_dims, len(vol.shape[:-1])) + ) + + if nb_dims > len(vol.shape): + raise Exception( + "Loc dimension %d does not match volume dimension %d" + % (nb_dims, len(vol.shape)) + ) + + if len(vol.shape) == nb_dims: + vol = K.expand_dims(vol, -1) + + # flatten and float location Tensors + loc = tf.cast(loc, "float32") + + if isinstance(vol.shape, tf.TensorShape): + volshape = vol.shape.as_list() + else: + volshape = vol.shape + + # interpolate + if interp_method == "linear": + loc0 = tf.floor(loc) + + # clip values + max_loc = [d - 1 for d in vol.get_shape().as_list()] + clipped_loc = [ + tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims) + ] + loc0lst = [ + tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims) + ] + + # get other end of point cube + loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)] + locs = [ + [tf.cast(f, "int32") for f in loc0lst], + [tf.cast(f, "int32") for f in loc1], + ] + + # compute the difference between the upper value and the original value + # differences are basically 1 - (pt - floor(pt)) + # because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt)) + diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)] + diff_loc0 = [1 - d for d in diff_loc1] + weights_loc = [ + diff_loc1, + diff_loc0, + ] # note reverse ordering since weights are inverse of diff. + + # go through all the cube corners, indexed by a ND binary vector + # e.g. [0, 0] means this "first" corner in a 2-D "cube" + cube_pts = list(itertools.product([0, 1], repeat=nb_dims)) + interp_vol = 0 + + for c in cube_pts: + # get nd values + # note re: indices above volumes via https://github.com/tensorflow/tensorflow/issues/15091 + # It works on GPU because we do not perform index validation checking on GPU -- it's too + # expensive. Instead we fill the output with zero for the corresponding value. The CPU + # version caught the bad index and returned the appropriate error. + subs = [locs[c[d]][d] for d in range(nb_dims)] + + idx = sub2ind(vol.shape[:-1], subs) + vol_val = tf.gather(tf.reshape(vol, [-1, volshape[-1]]), idx) + + # get the weight of this cube_pt based on the distance + # if c[d] is 0 --> want weight = 1 - (pt - floor[pt]) = diff_loc1 + # if c[d] is 1 --> want weight = pt - floor[pt] = diff_loc0 + wts_lst = [weights_loc[c[d]][d] for d in range(nb_dims)] + wt = prod_n(wts_lst) + wt = K.expand_dims(wt, -1) + + # compute final weighted value for each cube corner + interp_vol += wt * vol_val + + else: + assert interp_method == "nearest" + roundloc = tf.cast(tf.round(loc), "int32") + + # clip values + max_loc = [tf.cast(d - 1, "int32") for d in vol.shape] + roundloc = [ + tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims) + ] + + # get values + idx = sub2ind(vol.shape[:-1], roundloc) + interp_vol = tf.gather(tf.reshape(vol, [-1, vol.shape[-1]]), idx) + + return interp_vol + + +def resize(vol, zoom_factor, new_shape, interp_method="linear"): + """ + if zoom_factor is a list, it will determine the ndims, in which case vol has to be of length ndims or ndims + 1 + + if zoom_factor is an integer, then vol must be of length ndims + 1 + + new_shape should be a list of length ndims + + """ + + if isinstance(zoom_factor, (list, tuple)): + ndims = len(zoom_factor) + vol_shape = vol.shape[:ndims] + assert len(vol_shape) in ( + ndims, + ndims + 1, + ), "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims) + else: + vol_shape = vol.shape[:-1] + ndims = len(vol_shape) + zoom_factor = [zoom_factor] * ndims + + # get grid for new shape + grid = volshape_to_ndgrid(new_shape) + grid = [tf.cast(f, "float32") for f in grid] + offset = [grid[f] / zoom_factor[f] - grid[f] for f in range(ndims)] + offset = tf.stack(offset, ndims) + + # transform + return transform(vol, offset, interp_method) + + +zoom = resize + + +def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing="ij"): + """ + transform an affine matrix to a dense location shift tensor in tensorflow + + Algorithm: + - get grid and shift grid to be centered at the center of the image (optionally) + - apply affine matrix to each index. + - subtract grid + + Parameters: + affine_matrix: ND+1 x ND+1 or ND x ND+1 matrix (Tensor) + volshape: 1xN Nd Tensor of the size of the volume. + shift_center (optional) + indexing + + Returns: + shift field (Tensor) of size *volshape x N + """ + + if isinstance(volshape, tf.TensorShape): + volshape = volshape.as_list() + + if affine_matrix.dtype != "float32": + affine_matrix = tf.cast(affine_matrix, "float32") + + nb_dims = len(volshape) + + if len(affine_matrix.shape) == 1: + if len(affine_matrix) != (nb_dims * (nb_dims + 1)): + raise ValueError( + "transform is supposed a vector of len ndims * (ndims + 1)." + "Got len %d" % len(affine_matrix) + ) + + affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1]) + + if not ( + affine_matrix.shape[0] in [nb_dims, nb_dims + 1] + and affine_matrix.shape[1] == (nb_dims + 1) + ): + raise Exception( + "Affine matrix shape should match" + "%d+1 x %d+1 or " % (nb_dims, nb_dims) + + "%d x %d+1." % (nb_dims, nb_dims) + + "Got: " + + str(volshape) + ) + + # list of volume ndgrid + # N-long list, each entry of shape volshape + mesh = volshape_to_meshgrid(volshape, indexing=indexing) + mesh = [tf.cast(f, "float32") for f in mesh] + + if shift_center: + mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] + + # add an all-ones entry and transform into a large matrix + flat_mesh = [flatten(f) for f in mesh] + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype="float32")) + mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # 4 x nb_voxels + + # compute locations + loc_matrix = tf.matmul(affine_matrix, mesh_matrix) # N+1 x nb_voxels + loc_matrix = tf.transpose(loc_matrix[:nb_dims, :]) # nb_voxels x N + loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims]) # *volshape x N + + # get shifts and return + return loc - tf.stack(mesh, axis=nb_dims) + + +def combine_non_linear_and_aff_to_shift( + transform_list, volshape, shift_center=True, indexing="ij" +): + """ + transform an affine matrix to a dense location shift tensor in tensorflow + + Algorithm: + - get grid and shift grid to be centered at the center of the image (optionally) + - apply affine matrix to each index. + - subtract grid + + Parameters: + transform_list: list of non-linear tensor (size of volshape) and affine ND+1 x ND+1 or ND x ND+1 tensor + volshape: 1xN Nd Tensor of the size of the volume. + shift_center (optional) + indexing + + Returns: + shift field (Tensor) of size *volshape x N + """ + + if isinstance(volshape, tf.TensorShape): + volshape = volshape.as_list() + + # convert transforms to floats + for i in range(len(transform_list)): + if transform_list[i].dtype != "float32": + transform_list[i] = tf.cast(transform_list[i], "float32") + + nb_dims = len(volshape) + + # transform affine to matrix if given as vector + if len(transform_list[1].shape) == 1: + if len(transform_list[1]) != (nb_dims * (nb_dims + 1)): + raise ValueError( + "transform is supposed a vector of len ndims * (ndims + 1)." + "Got len %d" % len(transform_list[1]) + ) + + transform_list[1] = tf.reshape(transform_list[1], [nb_dims, nb_dims + 1]) + + if not ( + transform_list[1].shape[0] in [nb_dims, nb_dims + 1] + and transform_list[1].shape[1] == (nb_dims + 1) + ): + raise Exception( + "Affine matrix shape should match" + "%d+1 x %d+1 or " % (nb_dims, nb_dims) + + "%d x %d+1." % (nb_dims, nb_dims) + + "Got: " + + str(volshape) + ) + + # list of volume ndgrid + # N-long list, each entry of shape volshape + mesh = volshape_to_meshgrid(volshape, indexing=indexing) + mesh = [tf.cast(f, "float32") for f in mesh] + + if shift_center: + mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] + + # add an all-ones entry and transform into a large matrix + # non_linear_mesh = tf.unstack(transform_list[0], axis=3) + non_linear_mesh = tf.unstack(transform_list[0], axis=-1) + flat_mesh = [flatten(mesh[i] + non_linear_mesh[i]) for i in range(len(mesh))] + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype="float32")) + mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # N+1 x nb_voxels + + # compute locations + loc_matrix = tf.matmul(transform_list[1], mesh_matrix) # N+1 x nb_voxels + loc_matrix = tf.transpose(loc_matrix[:nb_dims, :]) # nb_voxels x N + loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims]) # *volshape x N + + # get shifts and return + return loc - tf.stack(mesh, axis=nb_dims) + + +def transform(vol, loc_shift, interp_method="linear", indexing="ij"): + """ + transform interpolation N-D volumes (features) given shifts at each location in tensorflow + + Essentially interpolates volume vol at locations determined by loc_shift. + This is a spatial transform in the sense that at location [x] we now have the data from, + [x + shift] so we've moved data. + + Parameters: + vol: volume with size vol_shape or [*vol_shape, nb_features] + loc_shift: shift volume [*new_vol_shape, N] + interp_method (default:'linear'): 'linear', 'nearest' + indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian). + In general, prefer to leave this 'ij' + + Return: + new interpolated volumes in the same size as loc_shift[0] + """ + + # parse shapes + if isinstance(loc_shift.shape, tf.TensorShape): + volshape = loc_shift.shape[:-1].as_list() + else: + volshape = loc_shift.shape[:-1] + nb_dims = len(volshape) + + # location should be meshed and delta + mesh = volshape_to_meshgrid(volshape, indexing=indexing) # volume mesh + loc = [tf.cast(mesh[d], "float32") + loc_shift[..., d] for d in range(nb_dims)] + + # test single + return interpn(vol, loc, interp_method=interp_method) + + +def integrate_vec(vec, time_dep=False, method="ss", **kwargs): + """ + Integrate (stationary of time-dependent) vector field (N-D Tensor) in tensorflow + + Aside from directly using tensorflow's numerical integration odeint(), also implements + "scaling and squaring", and quadrature. Note that the diff. equation given to odeint + is the one used in quadrature. + + Parameters: + vec: the Tensor field to integrate. + If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size), + then vector shape (vec_shape) should be + [vol_size, vol_ndim] (if stationary) + [vol_size, vol_ndim, nb_time_steps] (if time dependent) + time_dep: bool whether vector is time dependent + method: 'scaling_and_squaring' or 'ss' or 'quadrature' + + if using 'scaling_and_squaring': currently only supports integrating to time point 1. + nb_steps int number of steps. Note that this means the vec field gets broken own to 2**nb_steps. + so nb_steps of 0 means integral = vec. + + Returns: + int_vec: integral of vector field with same shape as the input + """ + + if method not in ["ss", "scaling_and_squaring", "ode", "quadrature"]: + raise ValueError( + "method has to be 'scaling_and_squaring' or 'ode'. found: %s" % method + ) + + if method in ["ss", "scaling_and_squaring"]: + nb_steps = kwargs["nb_steps"] + assert nb_steps >= 0, "nb_steps should be >= 0, found: %d" % nb_steps + + if time_dep: + svec = K.permute_dimensions(vec, [-1, *range(0, vec.shape[-1] - 1)]) + assert ( + 2**nb_steps == svec.shape[0] + ), "2**nb_steps and vector shape don't match" + + svec = svec / (2**nb_steps) + for _ in range(nb_steps): + svec = svec[0::2] + tf.map_fn(transform, svec[1::2, :], svec[0::2, :]) + + disp = svec[0, :] + + else: + vec = vec / (2**nb_steps) + for _ in range(nb_steps): + vec += transform(vec, vec) + disp = vec + + else: # method == 'quadrature': + nb_steps = kwargs["nb_steps"] + assert nb_steps >= 1, "nb_steps should be >= 1, found: %d" % nb_steps + + vec = vec / nb_steps + + if time_dep: + disp = vec[..., 0] + for si in range(nb_steps - 1): + disp += transform(vec[..., si + 1], disp) + else: + disp = vec + for _ in range(nb_steps - 1): + disp += transform(vec, disp) + + return disp + + +def volshape_to_ndgrid(volshape, **kwargs): + """ + compute Tensor ndgrid from a volume size + + Parameters: + volshape: the volume size + + Returns: + A list of Tensors + + See Also: + ndgrid + """ + + isint = [float(d).is_integer() for d in volshape] + if not all(isint): + raise ValueError("volshape needs to be a list of integers") + + linvec = [tf.range(0, d) for d in volshape] + return ndgrid(*linvec, **kwargs) + + +def volshape_to_meshgrid(volshape, **kwargs): + """ + compute Tensor meshgrid from a volume size + + Parameters: + volshape: the volume size + + Returns: + A list of Tensors + + See Also: + tf.meshgrid, meshgrid, ndgrid, volshape_to_ndgrid + """ + + isint = [float(d).is_integer() for d in volshape] + if not all(isint): + raise ValueError("volshape needs to be a list of integers") + + linvec = [tf.range(0, d) for d in volshape] + return meshgrid(*linvec, **kwargs) + + +def ndgrid(*args, **kwargs): + """ + broadcast Tensors on an N-D grid with ij indexing + uses meshgrid with ij indexing + + Parameters: + *args: Tensors with rank 1 + **args: "name" (optional) + + Returns: + A list of Tensors + + """ + return meshgrid(*args, indexing="ij", **kwargs) + + +def meshgrid(*args, **kwargs): + """ + + meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically + improves runtime by changing the last step to tiling instead of multiplication. + https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921 + + Broadcasts parameters for evaluation on an N-D grid. + Given N one-dimensional coordinate arrays `*args`, returns a list `outputs` + of N-D coordinate arrays for evaluating expressions on an N-D grid. + Notes: + `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions. + When the `indexing` argument is set to 'xy' (the default), the broadcasting + instructions for the first two dimensions are swapped. + Examples: + Calling `X, Y = meshgrid(x, y)` with the tensors + ```python + x = [1, 2, 3] + y = [4, 5, 6] + X, Y = meshgrid(x, y) + # X = [[1, 2, 3], + # [1, 2, 3], + # [1, 2, 3]] + # Y = [[4, 4, 4], + # [5, 5, 5], + # [6, 6, 6]] + ``` + Args: + *args: `Tensor`s with rank 1. + **kwargs: + - indexing: Either 'xy' or 'ij' (optional, default: 'xy'). + - name: A name for the operation (optional). + Returns: + outputs: A list of N `Tensor`s with rank N. + Raises: + TypeError: When no keyword arguments (kwargs) are passed. + ValueError: When indexing keyword argument is not one of `xy` or `ij`. + """ + + indexing = kwargs.pop("indexing", "xy") + if kwargs: + key = list(kwargs.keys())[0] + raise TypeError( + "'{}' is an invalid keyword argument " "for this function".format(key) + ) + + if indexing not in ("xy", "ij"): + raise ValueError("indexing parameter must be either 'xy' or 'ij'") + + # with ops.name_scope(name, "meshgrid", args) as name: + ndim = len(args) + s0 = (1,) * ndim + + # Prepare reshape by inserting dimensions with size 1 where needed + output = [] + for i, x in enumerate(args): + output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1 : :]))) + # Create parameters for broadcasting each tensor to the full size + shapes = [tf.size(x) for x in args] + sz = [x.get_shape().as_list()[0] for x in args] + + # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype + if indexing == "xy" and ndim > 1: + output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2)) + output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2)) + shapes[0], shapes[1] = shapes[1], shapes[0] + sz[0], sz[1] = sz[1], sz[0] + + for i in range(len(output)): + stack_sz = [*sz[:i], 1, *sz[(i + 1) :]] + if indexing == "xy" and ndim > 1 and i < 2: + stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0] + output[i] = tf.tile(output[i], tf.stack(stack_sz)) + return output + + +def flatten(v): + """flatten Tensor v""" + + return tf.reshape(v, [-1]) + + +def prod_n(lst): + prod = lst[0] + for p in lst[1:]: + prod *= p + return prod + + +def sub2ind(siz, subs): + """assumes column-order major""" + # subs is a list + assert len(siz) == len(subs), "found inconsistent siz and subs: %d %d" % ( + len(siz), + len(subs), + ) + + k = np.cumprod(siz[::-1]) + + ndx = subs[-1] + for i, v in enumerate(subs[:-1][::-1]): + ndx = ndx + v * k[i] + + return ndx diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py index a4842bc7..23420f21 100644 --- a/nobrainer/models/__init__.py +++ b/nobrainer/models/__init__.py @@ -4,8 +4,10 @@ from .attention_unet_with_inception import attention_unet_with_inception from .autoencoder import autoencoder from .bayesian_meshnet import variational_meshnet +from .bayesian_vnet import bayesian_vnet from .dcgan import dcgan from .highresnet import highresnet +from .labels_to_image_model import labels_to_image_model from .meshnet import meshnet from .progressiveae import progressiveae from .progressivegan import progressivegan @@ -26,6 +28,8 @@ "attention_unet_with_inception": attention_unet_with_inception, "unetr": unetr, "variational_meshnet": variational_meshnet, + "bayesian_vnet": bayesian_vnet, + "synthgenerator": labels_to_image_model, } diff --git a/nobrainer/models/labels_to_image_model.py b/nobrainer/models/labels_to_image_model.py new file mode 100644 index 00000000..4031b297 --- /dev/null +++ b/nobrainer/models/labels_to_image_model.py @@ -0,0 +1,381 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +import keras.layers as KL +from keras.models import Model + +# python imports +import numpy as np +import tensorflow as tf + +# third-party imports +from nobrainer.ext.lab2im import edit_tensors as l2i_et +from nobrainer.ext.lab2im import layers, utils +from nobrainer.ext.lab2im.edit_volumes import get_ras_axes + + +def labels_to_image_model( + labels_shape, + n_channels, + generation_labels, + output_labels, + n_neutral_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + flipping=True, + aff=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.0625, + randomise_res=False, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.5, + bias_scale=0.025, + return_gradients=False, +): + """ + This function builds a keras/tensorflow model to generate images from provided label maps. + The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. + The model will take as inputs: + -a label map + -a vector containing the means of the Gaussian Mixture Model for each label, + -a vector containing the standard deviations of the Gaussian Mixture Model for each label, + The model returns: + -the generated image normalised between 0 and 1. + -the corresponding label map, with only the labels present in output_labels (the other are reset to zero). + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array. + :param n_channels: number of channels to be synthesised. + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array. It should be organised as follows: background label first, then + non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same hemisphere (can be left or right), + and finally all the corresponding contralateral structures (in the same order). + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in the + label maps returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be + converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param n_neutral_labels: number of non-sided generation labels. + :param atlas_res: resolution of the input label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param target_res: target resolution of the generated images and corresponding label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + Default is None, where no cropping is performed. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape + if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param flipping: (optional) whether to introduce right/left random flipping + :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label map. + Used to find brain's right/left axis. Should be given if flipping is True. + :param scaling_bounds: (optional) range of the random scaling to apply at each mini-batch. The scaling factor for + each dimension is sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.2 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we + sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the elastic + deformation off. + :param nonlin_scale: (optional) if nonlin_std is strictly positive, factor between the shapes of the input + label maps and the shape of the input non-linear tensor. + :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and + 2) resampled to high resolution. The low resolution is uniformly resampled at each minibatch from [1mm, 9mm]. + In that process, the images generated by sampling the GMM are 1) blurred at the sampled LR, 2) downsampled at LR, + and 3) resampled at target_resolution. + :param max_res_iso: (optional) If randomise_res is True, this enables to control the upper bound of the uniform + distribution from which we sample the random resolution U(min_res, max_res_iso), where min_res is the resolution of + the input label maps. Must be a number, and default is 4. Set to None to deactivate it, but if randomise_res is + True, at least one of max_res_iso or max_res_aniso must be given. + :param max_res_aniso: If randomise_res is True, this enables to downsample the input volumes to a random LR in + only 1 (random) direction. This is done by randomly selecting a direction i in the range [0, n_dims-1], and sampling + a value in the corresponding uniform distribution U(min_res[i], max_res_aniso[i]), where min_res is the resolution + of the input label maps. Can be a number, a sequence, or a 1d numpy array. Set to None to deactivate it, but if + randomise_res is True, at least one of max_res_iso or max_res_aniso must be given. + :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled when + randomise_res is True. This triggers a blurring to mimic the specified acquisition resolution, but the downsampling + is optional (see param downsample). Default for data_res is None, where images are slightly blurred. + If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d + numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a + path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel. + :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low + resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res. + :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with a + bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full size, + and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the std dev of + the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field corruption. + :param bias_scale: (optional) If bias_field_std is strictly positive, this designates the ratio between the + size of the input label maps and the size of the first sampled tensor for synthesising the bias field. + :param return_gradients: (optional) whether to return the synthetic image or the magnitude of its spatial gradient + (computed with Sobel kernels). + """ + + # reformat resolutions + labels_shape = utils.reformat_to_list(labels_shape) + n_dims, _ = utils.get_dims(labels_shape) + atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims, n_channels) + data_res = ( + atlas_res + if data_res is None + else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) + ) + thickness = ( + data_res + if thickness is None + else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + ) + atlas_res = atlas_res[0] + target_res = ( + atlas_res + if target_res is None + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) + + # get shapes + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) + + # define model inputs + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="std_devs_input" + ) + list_inputs = [labels_input, means_input, stds_input] + + # deform labels + labels = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method="nearest", + )(labels_input) + + # cropping + if crop_shape != labels_shape: + labels = layers.RandomCrop(crop_shape)(labels) + + # flipping + if flipping: + assert aff is not None, "aff should not be None if flipping is True" + labels = layers.RandomFlip( + get_ras_axes(aff, n_dims)[0], True, generation_labels, n_neutral_labels + )(labels) + + # build synthetic image + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) + + # apply bias field + if bias_field_std > 0: + image = layers.BiasFieldCorruption(bias_field_std, bias_scale, False)(image) + + # intensity augmentation + image = layers.IntensityAugmentation( + clip=300, normalise=True, gamma_std=0.5, separate_channels=True + )(image) + + # loop over channels + channels = list() + split = ( + KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) + if (n_channels > 1) + else [image] + ) + for i, channel in enumerate(split): + + if randomise_res: + max_res_iso = np.array( + utils.reformat_to_list(max_res_iso, length=n_dims, dtype="float") + ) + max_res_aniso = np.array( + utils.reformat_to_list(max_res_aniso, length=n_dims, dtype="float") + ) + max_res = np.maximum(max_res_iso, max_res_aniso) + resolution, blur_res = layers.SampleResolution( + atlas_res, max_res_iso, max_res_aniso + )(means_input) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, resolution, thickness=blur_res + ) + channel = layers.DynamicGaussianBlur( + 0.75 * max_res / np.array(atlas_res), 1.03 + )([channel, sigma]) + channel = layers.MimicAcquisition( + atlas_res, atlas_res, output_shape, False + )([channel, resolution]) + channels.append(channel) + + else: + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, data_res[i], thickness=thickness[i] + ) + channel = layers.GaussianBlur(sigma, 1.03)(channel) + resolution = KL.Lambda( + lambda x: tf.convert_to_tensor(data_res[i], dtype="float32") + )([]) + channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)( + [channel, resolution] + ) + channels.append(channel) + + # concatenate all channels back + image = ( + KL.Lambda(lambda x: tf.concat(x, -1))(channels) + if len(channels) > 1 + else channels[0] + ) + + # compute image gradient + if return_gradients: + image = layers.ImageGradients("sobel", True, name="image_gradients")(image) + image = layers.IntensityAugmentation(clip=10, normalise=True)(image) + + # resample labels at target resolution + if crop_shape != output_shape: + labels = l2i_et.resample_tensor(labels, output_shape, interp_method="nearest") + + # map generation labels to segmentation values + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) + + # build model (dummy layer enables to keep the labels when plugging this model to other models) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) + brain_model = Model(inputs=list_inputs, outputs=[image, labels]) + + return brain_model + + +def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n): + + # reformat resolutions to lists + atlas_res = utils.reformat_to_list(atlas_res) + n_dims = len(atlas_res) + target_res = utils.reformat_to_list(target_res) + + # get resampling factor + if atlas_res != target_res: + resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)] + else: + resample_factor = None + + # output shape specified, need to get cropping shape, and resample shape if necessary + if output_shape is not None: + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int") + + # make sure that output shape is smaller or equal to label shape + if resample_factor is not None: + output_shape = [ + min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) + for i in range(n_dims) + ] + else: + output_shape = [ + min(labels_shape[i], output_shape[i]) for i in range(n_dims) + ] + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] + if output_shape != tmp_shape: + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + output_shape, output_div_by_n, tmp_shape + ) + ) + output_shape = tmp_shape + + # get cropping and resample shape + if resample_factor is not None: + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] + else: + cropping_shape = output_shape + + # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n + else: + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + + # if resampling, get the potential output_shape and check if it is divisible by n + if resample_factor is not None: + output_shape = [ + int(labels_shape[i] * resample_factor[i]) for i in range(n_dims) + ] + output_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] + # if no resampling, simply check if image_shape is divisible by n + else: + cropping_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in labels_shape + ] + output_shape = cropping_shape + + # if no need to be divisible by n, simply take cropping_shape as image_shape, and build output_shape + else: + cropping_shape = labels_shape + if resample_factor is not None: + output_shape = [ + int(cropping_shape[i] * resample_factor[i]) for i in range(n_dims) + ] + else: + output_shape = cropping_shape + + return cropping_shape, output_shape diff --git a/nobrainer/processing/brain_generator.py b/nobrainer/processing/brain_generator.py new file mode 100644 index 00000000..aea63b78 --- /dev/null +++ b/nobrainer/processing/brain_generator.py @@ -0,0 +1,358 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + +# python imports +import numpy as np + +# project imports +from nobrainer.ext.SynthSeg.model_inputs import build_model_inputs + +# third-party imports +from nobrainer.ext.lab2im import edit_volumes, utils +from nobrainer.models.labels_to_image_model import labels_to_image_model + + +class BrainGenerator: + + def __init__( + self, + labels_dir, + generation_labels=None, + n_neutral_labels=None, + output_labels=None, + subjects_prob=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + prior_distributions="uniform", + generation_classes=None, + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, + flipping=True, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + randomise_res=True, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + return_gradients=False, + ): + """ + This class is wrapper around the labels_to_image_model model. It contains the GPU model that generates images + from labels maps, and a python generator that supplies the input data for this model. + To generate pairs of image/labels you can just call the method generate_image() on an object of this class. + + :param labels_dir: path of folder with all input label maps, or to a single label map. + + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + + # label maps-related parameters + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array. + If flipping is true (i.e. right/left flipping is enabled), generation_labels should be organised as follows: + background label first, then non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same + hemisphere (can be left or right), and finally all the corresponding contralateral structures in the same order. + :param n_neutral_labels: (optional) number of non-sided generation labels. This is important only if you use + flipping augmentation. Default is total number of label values. + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in + the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps + will be converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to + pick the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an + array, and it must be as long as path_label_maps. By default, all label maps are chosen with the same importance + + # output-related parameters + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthesised. Default is 1. + :param target_res: (optional) target resolution of the generated images and corresponding label maps. + If None, the outputs will have the same resolution as the input label maps. + Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image. + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + Default is None, where no cropping is performed. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites + output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or + the path to a 1d numpy array. + + # GMM-sampling parameters + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a + 1d numpy array, or the path to a 1d numpy array. It should have the same length as generation_labels, and + contain values between 0 and K-1, where K is the total number of classes. + Default is all labels have different classes (K=len(generation_labels)). + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each + mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default + values for half of these cases, and thus generate images of random contrast. + + # spatial deformation parameters + :param flipping: (optional) whether to introduce right/left random flipping. Default is True. + :param scaling_bounds: (optional) range of the random sampling to apply at each mini-batch. The scaling factor + for each dimension is sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.2 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we + sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the + elastic deformation off. + :param nonlin_scale: (optional) if nonlin_std is strictly positive, factor between the shapes of the + input label maps and the shape of the input non-linear tensor. + + # blurring/resampling parameters + :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and + 2) resampled to high resolution. The low resolution is uniformly resampled at each minibatch from [1mm, 9mm]. + In that process, the images generated by sampling the GMM are: + 1) blurred at the sampled LR, 2) downsampled at LR, and 3) resampled at target_resolution. + :param max_res_iso: (optional) If randomise_res is True, this enables to control the upper bound of the uniform + distribution from which we sample the random resolution U(min_res, max_res_iso), where min_res is the resolution + of the input label maps. Must be a number, and default is 4. Set to None to deactivate it, but if randomise_res + is True, at least one of max_res_iso or max_res_aniso must be given. + :param max_res_aniso: If randomise_res is True, this enables to downsample the input volumes to a random LR + in only 1 (random) direction. This is done by randomly selecting a direction i in range [0, n_dims-1], and + sampling a value in the corresponding uniform distribution U(min_res[i], max_res_aniso[i]), where min_res is the + resolution of the input label maps. Can be a number, a sequence, or a 1d numpy array. Set to None to deactivate + it, but if randomise_res is True, at least one of max_res_iso or max_res_aniso must be given. + :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled + when randomise_res is True. This triggers a blurring which mimics the acquisition resolution, but downsampling + is optional (see param downsample). Default for data_res is None, where images are slightly blurred. + If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, + a 1d numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array + (or a path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel. + :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low + resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res. + + # bias field parameters + :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with + a bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full + size, and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the + std dev of the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field. + :param bias_scale: (optional) If bias_field_std is strictly positive, this designates the ratio between + the size of the input label maps and the size of the first sampled tensor for synthesising the bias field. + + :param return_gradients: (optional) whether to return the synthetic image or the magnitude of its spatial + gradient (computed with Sobel kernels). + """ + + # prepare data files + self.labels_paths = utils.list_images_in_folder(labels_dir) + if subjects_prob is not None: + self.subjects_prob = np.array( + utils.reformat_to_list(subjects_prob, load_as_numpy=True), + dtype="float32", + ) + assert len(self.subjects_prob) == len(self.labels_paths), ( + "subjects_prob should have the same length as labels_path, " + "had {} and {}".format(len(self.subjects_prob), len(self.labels_paths)) + ) + else: + self.subjects_prob = None + + # generation parameters + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = ( + utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + ) + self.n_channels = n_channels + if generation_labels is not None: + self.generation_labels = utils.load_array_if_path(generation_labels) + else: + self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir) + if output_labels is not None: + self.output_labels = utils.load_array_if_path(output_labels) + else: + self.output_labels = self.generation_labels + if n_neutral_labels is not None: + self.n_neutral_labels = n_neutral_labels + else: + self.n_neutral_labels = self.generation_labels.shape[0] + self.target_res = utils.load_array_if_path(target_res) + self.batchsize = batchsize + # preliminary operations + self.flipping = flipping + self.output_shape = utils.load_array_if_path(output_shape) + self.output_div_by_n = output_div_by_n + # GMM parameters + self.prior_distributions = prior_distributions + if generation_classes is not None: + self.generation_classes = utils.load_array_if_path(generation_classes) + assert ( + self.generation_classes.shape == self.generation_labels.shape + ), "if provided, generation_classes should have the same shape as generation_labels" + unique_classes = np.unique(self.generation_classes) + assert np.array_equal( + unique_classes, np.arange(np.max(unique_classes) + 1) + ), "generation_classes should a linear range between 0 and its maximum value." + else: + self.generation_classes = np.arange(self.generation_labels.shape[0]) + self.prior_means = utils.load_array_if_path(prior_means) + self.prior_stds = utils.load_array_if_path(prior_stds) + self.use_specific_stats_for_channel = use_specific_stats_for_channel + # linear transformation parameters + self.scaling_bounds = utils.load_array_if_path(scaling_bounds) + self.rotation_bounds = utils.load_array_if_path(rotation_bounds) + self.shearing_bounds = utils.load_array_if_path(shearing_bounds) + self.translation_bounds = utils.load_array_if_path(translation_bounds) + # elastic transformation parameters + self.nonlin_std = nonlin_std + self.nonlin_scale = nonlin_scale + # blurring parameters + self.randomise_res = randomise_res + self.max_res_iso = max_res_iso + self.max_res_aniso = max_res_aniso + self.data_res = utils.load_array_if_path(data_res) + assert not ( + self.randomise_res & (self.data_res is not None) + ), "randomise_res and data_res cannot be provided at the same time" + self.thickness = utils.load_array_if_path(thickness) + # bias field parameters + self.bias_field_std = bias_field_std + self.bias_scale = bias_scale + self.return_gradients = return_gradients + + # build transformation model + self.labels_to_image_model, self.model_output_shape = ( + self._build_labels_to_image_model() + ) + + # build generator for model inputs + self.model_inputs_generator = self._build_model_inputs_generator( + mix_prior_and_random + ) + + # build brain generator + self.brain_generator = self._build_brain_generator() + + def _build_labels_to_image_model(self): + # build_model + lab_to_im_model = labels_to_image_model( + labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + n_neutral_labels=self.n_neutral_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + flipping=self.flipping, + aff=np.eye(4), + scaling_bounds=self.scaling_bounds, + rotation_bounds=self.rotation_bounds, + shearing_bounds=self.shearing_bounds, + translation_bounds=self.translation_bounds, + nonlin_std=self.nonlin_std, + nonlin_scale=self.nonlin_scale, + randomise_res=self.randomise_res, + max_res_iso=self.max_res_iso, + max_res_aniso=self.max_res_aniso, + data_res=self.data_res, + thickness=self.thickness, + bias_field_std=self.bias_field_std, + bias_scale=self.bias_scale, + return_gradients=self.return_gradients, + ) + out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] + return lab_to_im_model, out_shape + + def _build_model_inputs_generator(self, mix_prior_and_random): + # build model's inputs generator + model_inputs_generator = build_model_inputs( + path_label_maps=self.labels_paths, + n_labels=len(self.generation_labels), + batchsize=self.batchsize, + n_channels=self.n_channels, + subjects_prob=self.subjects_prob, + generation_classes=self.generation_classes, + prior_means=self.prior_means, + prior_stds=self.prior_stds, + prior_distributions=self.prior_distributions, + use_specific_stats_for_channel=self.use_specific_stats_for_channel, + mix_prior_and_random=mix_prior_and_random, + ) + return model_inputs_generator + + def _build_brain_generator(self): + while True: + model_inputs = next(self.model_inputs_generator) + [image, labels] = self.labels_to_image_model.predict(model_inputs) + yield image, labels + + def generate_brain(self): + """call this method when an object of this class has been instantiated to generate new brains""" + (image, labels) = next(self.brain_generator) + # put back images in native space + list_images = list() + list_labels = list() + for i in range(self.batchsize): + list_images.append( + edit_volumes.align_volume_to_ref( + image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + list_labels.append( + edit_volumes.align_volume_to_ref( + labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + image = np.squeeze(np.stack(list_images, axis=0)) + labels = np.squeeze(np.stack(list_labels, axis=0)) + return image, labels diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py index bf54ef43..8ab71692 100644 --- a/nobrainer/processing/checkpoint.py +++ b/nobrainer/processing/checkpoint.py @@ -16,6 +16,7 @@ def __init__(self, estimator, file_path, **kwargs): file_path: str, directory to/from which to save or load. """ self.estimator = estimator + self.last_epoch = 0 super().__init__(file_path, **kwargs) def _save_model(self, epoch, batch, logs): diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index bc0da104..da6a0edb 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -5,7 +5,7 @@ from .base import BaseEstimator from .. import losses -from ..dataset import get_dataset +from ..dataset import Dataset class ProgressiveGeneration(BaseEstimator): @@ -136,8 +136,8 @@ def _compile(): d_loss_fn=d_loss, ) - print(self.model_.generator.summary()) - print(self.model_.discriminator.summary()) + self.model_.generator.summary() + self.model_.discriminator.summary() for resolution, info in dataset_train.items(): if resolution < self.current_resolution_: @@ -147,16 +147,20 @@ def _compile(): if batch_size % self.strategy.num_replicas_in_sync: raise ValueError("batch size must be a multiple of the number of GPUs") - dataset = get_dataset( + dataset = Dataset.from_tfrecords( file_pattern=info.get("file_pattern"), - batch_size=batch_size, num_parallel_calls=num_parallel_calls, volume_shape=(resolution, resolution, resolution), - n_classes=1, - scalar_label=True, - normalizer=info.get("normalizer") or normalizer, + scalar_labels=True, ) + if info.get("normalizer") or normalizer: + dataset = dataset.normalize(normalizer) + + n_epochs = info.get("epochs") or epochs + dataset = dataset.repeat(n_epochs).batch(batch_size) + steps_per_epoch = dataset.get_steps_per_epoch() + with self.strategy.scope(): # grow the networks by one (2^x) resolution if resolution > self.current_resolution_: @@ -164,10 +168,6 @@ def _compile(): self.model_.discriminator.add_resolution() _compile() - steps_per_epoch = (info.get("epochs") or epochs) // info.get( - "batch_size" - ) - # save_best_only is set to False as it is an adversarial loss model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( str(model_dir), @@ -182,7 +182,7 @@ def _compile(): print("Transition phase") self.model_.fit( - dataset, + dataset.dataset, phase="transition", resolution=resolution, steps_per_epoch=steps_per_epoch, # necessary for repeat dataset @@ -191,7 +191,7 @@ def _compile(): print("Resolution phase") self.model_.fit( - dataset, + dataset.dataset, phase="resolution", resolution=resolution, steps_per_epoch=steps_per_epoch, diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 7ba9a32a..11aadab7 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -38,11 +38,6 @@ def __init__( self.volume_shape_ = None self.scalar_labels_ = None - def add_model(self, base_model, model_args=None): - """Add a segmentation model""" - self.base_model = base_model - self.model_args = model_args or {} - def fit( self, dataset_train, @@ -55,6 +50,7 @@ def fit( metrics=metrics.dice, callbacks=None, verbose=1, + initial_epoch=0, ): """Train a segmentation model""" # TODO: check validity of datasets @@ -62,7 +58,7 @@ def fit( batch_size = dataset_train.batch_size self.block_shape_ = dataset_train.block_shape self.volume_shape_ = dataset_train.volume_shape - self.scalar_labels_ = dataset_train.scalar_labels + # self.scalar_labels_ = dataset_train.scalar_labels n_classes = dataset_train.n_classes opt_args = opt_args or {} if optimizer is None: @@ -104,8 +100,12 @@ def _compile(): if callbacks is None: callbacks = [] + dataset_train.repeat(epochs) + dataset_validate.repeat(epochs) + if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) + initial_epoch = self.checkpoint_tracker.last_epoch self.model_.fit( dataset_train.dataset, epochs=epochs, @@ -116,6 +116,7 @@ def _compile(): ), callbacks=callbacks, verbose=verbose, + initial_epoch=initial_epoch, ) return self diff --git a/nobrainer/tfrecord.py b/nobrainer/tfrecord.py index 405ba87d..a4701fde 100644 --- a/nobrainer/tfrecord.py +++ b/nobrainer/tfrecord.py @@ -26,7 +26,6 @@ def write( to_ras=True, compressed=True, processes=None, - chunksize=1, multi_resolution=False, resolutions=None, verbose=1, @@ -53,15 +52,15 @@ def write( writing to multiple TFRecord files (i.e., `examples_per_shard` < `len(features_labels)`). If `None`, uses all available cores. - chunksize: int, multiprocessing chunksize. multi_resolution: boolean, if `True`, different tfrecords for each resolution in each shard resolutions: list of ints, if multi_resolution is `True`, set resolutions for which tfrecords are created. For example, [4, 8, 16, 32, 64, 128, 256] verbose: int, if 1, print progress bar. If 0, print nothing. """ n_examples = len(features_labels) - n_shards = math.ceil(n_examples / examples_per_shard) - shards = np.array_split(features_labels, n_shards) + shards = np.array_split( + features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard) + ) # Test that the `filename_template` has a `shard` formatting key. try: @@ -80,7 +79,9 @@ def write( # This is the object that returns a protocol buffer string of the feature and label # on each iteration. It is pickle-able, unlike a generator. proto_iterators = [ - _ProtoIterator(s, multi_resolution=multi_resolution, resolutions=resolutions) + _ProtoIterator( + s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions + ) for s in shards ] # Set up positional arguments for the core writer function. @@ -90,14 +91,14 @@ def write( # Set keyword arguments so the resulting function accepts one positional argument. map_fn = functools.partial( _write_tfrecords, - compressed=True, + compressed=compressed, multi_resolution=multi_resolution, resolutions=resolutions, ) if processes is None: processes = get_num_parallel() - Parallel(n_jobs=processes, verbose=10)( + Parallel(n_jobs=processes, verbose=verbose)( delayed(__writer_func)(val, map_fn) for val in iterable ) from joblib.externals.loky import get_reusable_executor