From a534bc45ed57269a48cfe52fca97c1abfdbd42d1 Mon Sep 17 00:00:00 2001 From: Aakanksha Rana <40461936+Aakanksha-Rana@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:38:43 -0600 Subject: [PATCH 01/14] [WIP] ConvneXT Models for Classification and Segmentation (#210) * Update CHANGELOG.md [skip ci] * Create DepthwiseConv3d.py * for classification * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update convnext.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update DepthwiseConv3d.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update convnext.py * Update convnext.py fix typo --------- Co-authored-by: Satrajit Ghosh Co-authored-by: Nobrainer Bot Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: H Gazula Co-authored-by: Aakanksha Rana --- nobrainer/layers/DepthwiseConv3d.py | 338 ++++++++++++++++++++++++++++ nobrainer/models/convnext.py | 196 ++++++++++++++++ 2 files changed, 534 insertions(+) create mode 100644 nobrainer/layers/DepthwiseConv3d.py create mode 100644 nobrainer/models/convnext.py diff --git a/nobrainer/layers/DepthwiseConv3d.py b/nobrainer/layers/DepthwiseConv3d.py new file mode 100644 index 00000000..2cfd0efa --- /dev/null +++ b/nobrainer/layers/DepthwiseConv3d.py @@ -0,0 +1,338 @@ +""" +Directly taken from +https://github.com/alexandrosstergiou/keras-DepthwiseConv3D/blob/master/DepthwiseConv3D.py +This is a modification of the SeparableConv3D code in Keras, +to perform just the Depthwise Convolution (1st step) of the +Depthwise Separable Convolution layer. +""" +from __future__ import absolute_import + +from keras import backend as K +from keras import constraints, initializers, regularizers +from keras.backend.tensorflow_backend import ( + _preprocess_conv3d_input, + _preprocess_padding, +) +from keras.engine import InputSpec +from keras.layers import Conv3D +from keras.legacy.interfaces import conv3d_args_preprocessor +from keras.utils import conv_utils +import tensorflow as tf + + +def depthwise_conv3d_args_preprocessor(args, kwargs): + converted = [] + + if "init" in kwargs: + init = kwargs.pop("init") + kwargs["depthwise_initializer"] = init + converted.append(("init", "depthwise_initializer")) + + args, kwargs, _converted = conv3d_args_preprocessor(args, kwargs) + return args, kwargs, converted + _converted + + +# legacy_depthwise_conv3d_support = generate_legacy_interface( +# allowed_positional_args=["filters", "kernel_size"], +# conversions=[ +# ("nb_filter", "filters"), +# ("subsample", "strides"), +# ("border_mode", "padding"), +# ("dim_ordering", "data_format"), +# ("b_regularizer", "bias_regularizer"), +# ("b_constraint", "bias_constraint"), +# ("bias", "use_bias"), +# ], +# value_conversions={ +# "dim_ordering": { +# "tf": "channels_last", +# "th": "channels_first", +# "default": None, +# } +# }, +# preprocessor=depthwise_conv3d_args_preprocessor, +# ) + + +class DepthwiseConv3D(Conv3D): + """Depthwise 3D convolution. + Depth-wise part of separable convolutions consist in performing + just the first step/operation + (which acts on each input channel separately). + It does not perform the pointwise convolution (second step). + The `depth_multiplier` argument controls how many + output channels are generated per input channel in the depthwise step. + # Arguments + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, width and height of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, width and height. + Can be a single integer to specify the same value for + all spatial dimensions. + padding: one of `"valid"` or `"same"` (case-insensitive). + depth_multiplier: The number of depthwise convolution output channels + for each input channel. + The total number of depthwise convolution output + channels will be equal to `filterss_in * depth_multiplier`. + groups: The depth size of the convolution (as a variant of the original Depthwise conv) + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, height, width)`. + It defaults to the `image_data_format` value found in your + Keras config file at `~/.keras/keras.json`. + If you never set it, then it will be "channels_last". + activation: Activation function to use + (see [activations](../activations.md)). + If you don't specify anything, no activation is applied + (ie. "linear" activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + depthwise_initializer: Initializer for the depthwise kernel matrix + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + depthwise_regularizer: Regularizer function applied to + the depthwise kernel matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + dialation_rate: List of ints. + Defines the dilation factor for each dimension in the + input. Defaults to (1,1,1) + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + depthwise_constraint: Constraint function applied to + the depthwise kernel matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + # Input shape + 5D tensor with shape: + `(batch, depth, channels, rows, cols)` if data_format='channels_first' + or 5D tensor with shape: + `(batch, depth, rows, cols, channels)` if data_format='channels_last'. + # Output shape + 5D tensor with shape: + `(batch, filters * depth, new_depth, new_rows, new_cols)` if data_format='channels_first' + or 4D tensor with shape: + `(batch, new_depth, new_rows, new_cols, filters * depth)` if data_format='channels_last'. + `rows` and `cols` values might have changed due to padding. + """ + + # @legacy_depthwise_conv3d_support + def __init__( + self, + kernel_size, + strides=(1, 1, 1), + padding="valid", + depth_multiplier=1, + groups=None, + data_format=None, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + dilation_rate=(1, 1, 1), + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs + ): + super(DepthwiseConv3D, self).__init__( + filters=None, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + bias_regularizer=bias_regularizer, + dilation_rate=dilation_rate, + activity_regularizer=activity_regularizer, + bias_constraint=bias_constraint, + **kwargs + ) + self.depth_multiplier = depth_multiplier + self.groups = groups + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_initializer = initializers.get(bias_initializer) + self.dilation_rate = dilation_rate + self._padding = _preprocess_padding(self.padding) + self._strides = (1,) + self.strides + (1,) + self._data_format = "NDHWC" + self.input_dim = None + + def build(self, input_shape): + if len(input_shape) < 5: + raise ValueError( + "Inputs to `DepthwiseConv3D` should have rank 5. " + "Received input shape:", + str(input_shape), + ) + if self.data_format == "channels_first": + channel_axis = 1 + else: + channel_axis = -1 + if input_shape[channel_axis] is None: + raise ValueError( + "The channel dimension of the inputs to " + "`DepthwiseConv3D` " + "should be defined. Found `None`." + ) + self.input_dim = int(input_shape[channel_axis]) + + if self.groups is None: + self.groups = self.input_dim + + if self.groups > self.input_dim: + raise ValueError( + "The number of groups cannot exceed the number of channels" + ) + + if self.input_dim % self.groups != 0: + raise ValueError( + "Warning! The channels dimension is not divisible by the group size chosen" + ) + + depthwise_kernel_shape = ( + self.kernel_size[0], + self.kernel_size[1], + self.kernel_size[2], + self.input_dim, + self.depth_multiplier, + ) + + self.depthwise_kernel = self.add_weight( + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + name="depthwise_kernel", + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint, + ) + + if self.use_bias: + self.bias = self.add_weight( + shape=(self.groups * self.depth_multiplier,), + initializer=self.bias_initializer, + name="bias", + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + ) + else: + self.bias = None + # Set input spec. + self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim}) + self.built = True + + def call(self, inputs, training=None): + inputs = _preprocess_conv3d_input(inputs, self.data_format) + + if self.data_format == "channels_last": + dilation = (1,) + self.dilation_rate + (1,) + else: + dilation = self.dilation_rate + (1,) + (1,) + + if self._data_format == "NCDHW": + outputs = tf.concat( + [ + tf.nn.conv3d( + inputs[0][:, i : i + self.input_dim // self.groups, :, :, :], + self.depthwise_kernel[ + :, :, :, i : i + self.input_dim // self.groups, : + ], + strides=self._strides, + padding=self._padding, + dilations=dilation, + data_format=self._data_format, + ) + for i in range(0, self.input_dim, self.input_dim // self.groups) + ], + axis=1, + ) + + else: + outputs = tf.concat( + [ + tf.nn.conv3d( + inputs[0][:, :, :, :, i : i + self.input_dim // self.groups], + self.depthwise_kernel[ + :, :, :, i : i + self.input_dim // self.groups, : + ], + strides=self._strides, + padding=self._padding, + dilations=dilation, + data_format=self._data_format, + ) + for i in range(0, self.input_dim, self.input_dim // self.groups) + ], + axis=-1, + ) + + if self.bias is not None: + outputs = K.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + depth = input_shape[2] + rows = input_shape[3] + cols = input_shape[4] + out_filters = self.groups * self.depth_multiplier + elif self.data_format == "channels_last": + depth = input_shape[1] + rows = input_shape[2] + cols = input_shape[3] + out_filters = self.groups * self.depth_multiplier + + depth = conv_utils.conv_output_length( + depth, self.kernel_size[0], self.padding, self.strides[0] + ) + + rows = conv_utils.conv_output_length( + rows, self.kernel_size[1], self.padding, self.strides[1] + ) + + cols = conv_utils.conv_output_length( + cols, self.kernel_size[2], self.padding, self.strides[2] + ) + + if self.data_format == "channels_first": + return (input_shape[0], out_filters, depth, rows, cols) + + elif self.data_format == "channels_last": + return (input_shape[0], depth, rows, cols, out_filters) + + def get_config(self): + config = super(DepthwiseConv3D, self).get_config() + config.pop("filters") + config.pop("kernel_initializer") + config.pop("kernel_regularizer") + config.pop("kernel_constraint") + config["depth_multiplier"] = self.depth_multiplier + config["depthwise_initializer"] = initializers.serialize( + self.depthwise_initializer + ) + config["depthwise_regularizer"] = regularizers.serialize( + self.depthwise_regularizer + ) + config["depthwise_constraint"] = constraints.serialize( + self.depthwise_constraint + ) + return config + + +DepthwiseConvolution3D = DepthwiseConv3D diff --git a/nobrainer/models/convnext.py b/nobrainer/models/convnext.py new file mode 100644 index 00000000..b327670f --- /dev/null +++ b/nobrainer/models/convnext.py @@ -0,0 +1,196 @@ +import numpy as np +import tensorflow as tf +from tensorflow.keras import layers + +from ..layers.DepthwiseConv3d import DepthwiseConv3D + + +def drop_path(inputs, drop_prob, is_training): + # https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py + if (not is_training) or (drop_prob == 0.0): + return inputs + + # Compute keep_prob + keep_prob = 1.0 - drop_prob + + # Compute drop_connect tensor + random_tensor = keep_prob + shape = (tf.shape(inputs)[0],) + (1,) * (len(tf.shape(inputs)) - 1) + random_tensor += tf.random.uniform(shape, dtype=inputs.dtype) + binary_tensor = tf.floor(random_tensor) + output = tf.math.divide(inputs, keep_prob) * binary_tensor + return output + + +class DropPath(tf.keras.layers.Layer): + # https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def call(self, x, training=None): + return drop_path(x, self.drop_prob, training) + + +class Block(layers.Layer): + """ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) + -> 1x1x1 Conv -> GELU -> 1x1x1 Conv; all in (N, C, H, W, D) + (2) DwConv -> Permute to (N, H, W, D, C); + LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6, prefix=""): + super().__init__() + self.dwconv = DepthwiseConv3D(kernel_size=7, padding="same") # depthwise conv + self.norm = layers.LayerNormalization(epsilon=1e-6) + # pointwise/1x1x1 convs, implemented with linear layers + self.pwconv1 = layers.Dense(4 * dim) + self.act = tf.keras.activations.gelu + self.pwconv2 = layers.Dense(dim) + self.drop_path = DropPath(drop_path) + self.dim = dim + self.layer_scale_init_value = layer_scale_init_value + self.prefix = prefix + + def build(self, input_shape): + self.gamma = tf.Variable( + initial_value=self.layer_scale_init_value * tf.ones((self.dim)), + trainable=True, + name=f"{self.prefix}/gamma", + ) + self.built = True + + def call(self, x): + input = x + x = self.dwconv(x) + # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXt(tf.keras.Model): + """3D ConvNeXt Classification Model. + + Adapted from 2D Tensorflow keras impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + Args: + num_classes (int): Number of classes for classification head. Default: 1 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + include_top (bool): whether to add head or + just use it as feature extractor. Default: True + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for + classifier weights and biases. Default: 1. + """ + + def __init__( + self, + num_classes=1, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + include_top=True, + drop_path_rate=0.0, + layer_scale_init_value=1e-6, + head_init_scale=1.0, + ): + super().__init__() + self.include_top = include_top + self.downsample_layers = [] # stem and 3 intermediate downsampling conv layers + stem = tf.keras.Sequential( + [ + layers.Conv3D(dims[0], kernel_size=4, strides=4, padding="same"), + layers.LayerNormalization(epsilon=1e-6), + ] + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = tf.keras.Sequential( + [ + layers.LayerNormalization(epsilon=1e-6), + layers.Conv3D( + dims[i + 1], kernel_size=2, strides=2, padding="same" + ), + ] + ) + self.downsample_layers.append(downsample_layer) + + self.stages = ( + [] + ) # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates = [x for x in np.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = tf.keras.Sequential( + [ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + prefix=f"block{i}", + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + if self.include_top: + self.avg = layers.GlobalAveragePooling3D() + self.norm = layers.LayerNormalization(epsilon=1e-6) # final norm layer + self.head = layers.Dense(num_classes) + else: + self.avg = None + self.norm = None + self.head = None + + def forward_features(self, x): + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return x + + def call(self, x): + x = self.forward_features(x) + if self.include_top: + x = self.avg(x) + x = self.norm(x) + x = self.head(x) + return x + + +model_configs = dict( + convnext_tiny=dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]), + convnext_small=dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]), + convnext_base=dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]), + convnext_large=dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]), + convnext_xlarge=dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048]), +) + + +def create_model( + input_shape=(128, 128, 128, 1), + num_classes=1, + include_top=True, + model_name="convnext_tiny_1k", + **kwargs, +): + cfg = model_configs["_".join(model_name.split("_")[:2])] + net = ConvNeXt(num_classes, cfg["depths"], cfg["dims"], include_top, **kwargs) + net(tf.keras.Input(shape=input_shape)) + return net From ad24b265e631b947d2eed662e318e7d56be700c5 Mon Sep 17 00:00:00 2001 From: H Gazula Date: Fri, 22 Mar 2024 19:58:35 -0400 Subject: [PATCH 02/14] Revert "[WIP] ConvneXT Models for Classification and Segmentation (#210)" (#307) This reverts commit a534bc45ed57269a48cfe52fca97c1abfdbd42d1. --- nobrainer/layers/DepthwiseConv3d.py | 338 ---------------------------- nobrainer/models/convnext.py | 196 ---------------- 2 files changed, 534 deletions(-) delete mode 100644 nobrainer/layers/DepthwiseConv3d.py delete mode 100644 nobrainer/models/convnext.py diff --git a/nobrainer/layers/DepthwiseConv3d.py b/nobrainer/layers/DepthwiseConv3d.py deleted file mode 100644 index 2cfd0efa..00000000 --- a/nobrainer/layers/DepthwiseConv3d.py +++ /dev/null @@ -1,338 +0,0 @@ -""" -Directly taken from -https://github.com/alexandrosstergiou/keras-DepthwiseConv3D/blob/master/DepthwiseConv3D.py -This is a modification of the SeparableConv3D code in Keras, -to perform just the Depthwise Convolution (1st step) of the -Depthwise Separable Convolution layer. -""" -from __future__ import absolute_import - -from keras import backend as K -from keras import constraints, initializers, regularizers -from keras.backend.tensorflow_backend import ( - _preprocess_conv3d_input, - _preprocess_padding, -) -from keras.engine import InputSpec -from keras.layers import Conv3D -from keras.legacy.interfaces import conv3d_args_preprocessor -from keras.utils import conv_utils -import tensorflow as tf - - -def depthwise_conv3d_args_preprocessor(args, kwargs): - converted = [] - - if "init" in kwargs: - init = kwargs.pop("init") - kwargs["depthwise_initializer"] = init - converted.append(("init", "depthwise_initializer")) - - args, kwargs, _converted = conv3d_args_preprocessor(args, kwargs) - return args, kwargs, converted + _converted - - -# legacy_depthwise_conv3d_support = generate_legacy_interface( -# allowed_positional_args=["filters", "kernel_size"], -# conversions=[ -# ("nb_filter", "filters"), -# ("subsample", "strides"), -# ("border_mode", "padding"), -# ("dim_ordering", "data_format"), -# ("b_regularizer", "bias_regularizer"), -# ("b_constraint", "bias_constraint"), -# ("bias", "use_bias"), -# ], -# value_conversions={ -# "dim_ordering": { -# "tf": "channels_last", -# "th": "channels_first", -# "default": None, -# } -# }, -# preprocessor=depthwise_conv3d_args_preprocessor, -# ) - - -class DepthwiseConv3D(Conv3D): - """Depthwise 3D convolution. - Depth-wise part of separable convolutions consist in performing - just the first step/operation - (which acts on each input channel separately). - It does not perform the pointwise convolution (second step). - The `depth_multiplier` argument controls how many - output channels are generated per input channel in the depthwise step. - # Arguments - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, width and height of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, width and height. - Can be a single integer to specify the same value for - all spatial dimensions. - padding: one of `"valid"` or `"same"` (case-insensitive). - depth_multiplier: The number of depthwise convolution output channels - for each input channel. - The total number of depthwise convolution output - channels will be equal to `filterss_in * depth_multiplier`. - groups: The depth size of the convolution (as a variant of the original Depthwise conv) - data_format: A string, - one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, height, width)`. - It defaults to the `image_data_format` value found in your - Keras config file at `~/.keras/keras.json`. - If you never set it, then it will be "channels_last". - activation: Activation function to use - (see [activations](../activations.md)). - If you don't specify anything, no activation is applied - (ie. "linear" activation: `a(x) = x`). - use_bias: Boolean, whether the layer uses a bias vector. - depthwise_initializer: Initializer for the depthwise kernel matrix - (see [initializers](../initializers.md)). - bias_initializer: Initializer for the bias vector - (see [initializers](../initializers.md)). - depthwise_regularizer: Regularizer function applied to - the depthwise kernel matrix - (see [regularizer](../regularizers.md)). - bias_regularizer: Regularizer function applied to the bias vector - (see [regularizer](../regularizers.md)). - dialation_rate: List of ints. - Defines the dilation factor for each dimension in the - input. Defaults to (1,1,1) - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation"). - (see [regularizer](../regularizers.md)). - depthwise_constraint: Constraint function applied to - the depthwise kernel matrix - (see [constraints](../constraints.md)). - bias_constraint: Constraint function applied to the bias vector - (see [constraints](../constraints.md)). - # Input shape - 5D tensor with shape: - `(batch, depth, channels, rows, cols)` if data_format='channels_first' - or 5D tensor with shape: - `(batch, depth, rows, cols, channels)` if data_format='channels_last'. - # Output shape - 5D tensor with shape: - `(batch, filters * depth, new_depth, new_rows, new_cols)` if data_format='channels_first' - or 4D tensor with shape: - `(batch, new_depth, new_rows, new_cols, filters * depth)` if data_format='channels_last'. - `rows` and `cols` values might have changed due to padding. - """ - - # @legacy_depthwise_conv3d_support - def __init__( - self, - kernel_size, - strides=(1, 1, 1), - padding="valid", - depth_multiplier=1, - groups=None, - data_format=None, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - dilation_rate=(1, 1, 1), - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super(DepthwiseConv3D, self).__init__( - filters=None, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - activation=activation, - use_bias=use_bias, - bias_regularizer=bias_regularizer, - dilation_rate=dilation_rate, - activity_regularizer=activity_regularizer, - bias_constraint=bias_constraint, - **kwargs - ) - self.depth_multiplier = depth_multiplier - self.groups = groups - self.depthwise_initializer = initializers.get(depthwise_initializer) - self.depthwise_regularizer = regularizers.get(depthwise_regularizer) - self.depthwise_constraint = constraints.get(depthwise_constraint) - self.bias_initializer = initializers.get(bias_initializer) - self.dilation_rate = dilation_rate - self._padding = _preprocess_padding(self.padding) - self._strides = (1,) + self.strides + (1,) - self._data_format = "NDHWC" - self.input_dim = None - - def build(self, input_shape): - if len(input_shape) < 5: - raise ValueError( - "Inputs to `DepthwiseConv3D` should have rank 5. " - "Received input shape:", - str(input_shape), - ) - if self.data_format == "channels_first": - channel_axis = 1 - else: - channel_axis = -1 - if input_shape[channel_axis] is None: - raise ValueError( - "The channel dimension of the inputs to " - "`DepthwiseConv3D` " - "should be defined. Found `None`." - ) - self.input_dim = int(input_shape[channel_axis]) - - if self.groups is None: - self.groups = self.input_dim - - if self.groups > self.input_dim: - raise ValueError( - "The number of groups cannot exceed the number of channels" - ) - - if self.input_dim % self.groups != 0: - raise ValueError( - "Warning! The channels dimension is not divisible by the group size chosen" - ) - - depthwise_kernel_shape = ( - self.kernel_size[0], - self.kernel_size[1], - self.kernel_size[2], - self.input_dim, - self.depth_multiplier, - ) - - self.depthwise_kernel = self.add_weight( - shape=depthwise_kernel_shape, - initializer=self.depthwise_initializer, - name="depthwise_kernel", - regularizer=self.depthwise_regularizer, - constraint=self.depthwise_constraint, - ) - - if self.use_bias: - self.bias = self.add_weight( - shape=(self.groups * self.depth_multiplier,), - initializer=self.bias_initializer, - name="bias", - regularizer=self.bias_regularizer, - constraint=self.bias_constraint, - ) - else: - self.bias = None - # Set input spec. - self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim}) - self.built = True - - def call(self, inputs, training=None): - inputs = _preprocess_conv3d_input(inputs, self.data_format) - - if self.data_format == "channels_last": - dilation = (1,) + self.dilation_rate + (1,) - else: - dilation = self.dilation_rate + (1,) + (1,) - - if self._data_format == "NCDHW": - outputs = tf.concat( - [ - tf.nn.conv3d( - inputs[0][:, i : i + self.input_dim // self.groups, :, :, :], - self.depthwise_kernel[ - :, :, :, i : i + self.input_dim // self.groups, : - ], - strides=self._strides, - padding=self._padding, - dilations=dilation, - data_format=self._data_format, - ) - for i in range(0, self.input_dim, self.input_dim // self.groups) - ], - axis=1, - ) - - else: - outputs = tf.concat( - [ - tf.nn.conv3d( - inputs[0][:, :, :, :, i : i + self.input_dim // self.groups], - self.depthwise_kernel[ - :, :, :, i : i + self.input_dim // self.groups, : - ], - strides=self._strides, - padding=self._padding, - dilations=dilation, - data_format=self._data_format, - ) - for i in range(0, self.input_dim, self.input_dim // self.groups) - ], - axis=-1, - ) - - if self.bias is not None: - outputs = K.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - depth = input_shape[2] - rows = input_shape[3] - cols = input_shape[4] - out_filters = self.groups * self.depth_multiplier - elif self.data_format == "channels_last": - depth = input_shape[1] - rows = input_shape[2] - cols = input_shape[3] - out_filters = self.groups * self.depth_multiplier - - depth = conv_utils.conv_output_length( - depth, self.kernel_size[0], self.padding, self.strides[0] - ) - - rows = conv_utils.conv_output_length( - rows, self.kernel_size[1], self.padding, self.strides[1] - ) - - cols = conv_utils.conv_output_length( - cols, self.kernel_size[2], self.padding, self.strides[2] - ) - - if self.data_format == "channels_first": - return (input_shape[0], out_filters, depth, rows, cols) - - elif self.data_format == "channels_last": - return (input_shape[0], depth, rows, cols, out_filters) - - def get_config(self): - config = super(DepthwiseConv3D, self).get_config() - config.pop("filters") - config.pop("kernel_initializer") - config.pop("kernel_regularizer") - config.pop("kernel_constraint") - config["depth_multiplier"] = self.depth_multiplier - config["depthwise_initializer"] = initializers.serialize( - self.depthwise_initializer - ) - config["depthwise_regularizer"] = regularizers.serialize( - self.depthwise_regularizer - ) - config["depthwise_constraint"] = constraints.serialize( - self.depthwise_constraint - ) - return config - - -DepthwiseConvolution3D = DepthwiseConv3D diff --git a/nobrainer/models/convnext.py b/nobrainer/models/convnext.py deleted file mode 100644 index b327670f..00000000 --- a/nobrainer/models/convnext.py +++ /dev/null @@ -1,196 +0,0 @@ -import numpy as np -import tensorflow as tf -from tensorflow.keras import layers - -from ..layers.DepthwiseConv3d import DepthwiseConv3D - - -def drop_path(inputs, drop_prob, is_training): - # https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py - if (not is_training) or (drop_prob == 0.0): - return inputs - - # Compute keep_prob - keep_prob = 1.0 - drop_prob - - # Compute drop_connect tensor - random_tensor = keep_prob - shape = (tf.shape(inputs)[0],) + (1,) * (len(tf.shape(inputs)) - 1) - random_tensor += tf.random.uniform(shape, dtype=inputs.dtype) - binary_tensor = tf.floor(random_tensor) - output = tf.math.divide(inputs, keep_prob) * binary_tensor - return output - - -class DropPath(tf.keras.layers.Layer): - # https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py - def __init__(self, drop_prob=None): - super().__init__() - self.drop_prob = drop_prob - - def call(self, x, training=None): - return drop_path(x, self.drop_prob, training) - - -class Block(layers.Layer): - """ConvNeXt Block. There are two equivalent implementations: - (1) DwConv -> LayerNorm (channels_first) - -> 1x1x1 Conv -> GELU -> 1x1x1 Conv; all in (N, C, H, W, D) - (2) DwConv -> Permute to (N, H, W, D, C); - LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back - We use (2) as we find it slightly faster in PyTorch - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - """ - - def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6, prefix=""): - super().__init__() - self.dwconv = DepthwiseConv3D(kernel_size=7, padding="same") # depthwise conv - self.norm = layers.LayerNormalization(epsilon=1e-6) - # pointwise/1x1x1 convs, implemented with linear layers - self.pwconv1 = layers.Dense(4 * dim) - self.act = tf.keras.activations.gelu - self.pwconv2 = layers.Dense(dim) - self.drop_path = DropPath(drop_path) - self.dim = dim - self.layer_scale_init_value = layer_scale_init_value - self.prefix = prefix - - def build(self, input_shape): - self.gamma = tf.Variable( - initial_value=self.layer_scale_init_value * tf.ones((self.dim)), - trainable=True, - name=f"{self.prefix}/gamma", - ) - self.built = True - - def call(self, x): - input = x - x = self.dwconv(x) - # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) - - x = input + self.drop_path(x) - return x - - -class ConvNeXt(tf.keras.Model): - """3D ConvNeXt Classification Model. - - Adapted from 2D Tensorflow keras impl of : `A ConvNet for the 2020s` - - https://arxiv.org/pdf/2201.03545.pdf - Args: - num_classes (int): Number of classes for classification head. Default: 1 - depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] - dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] - include_top (bool): whether to add head or - just use it as feature extractor. Default: True - drop_path_rate (float): Stochastic depth rate. Default: 0. - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - head_init_scale (float): Init scaling value for - classifier weights and biases. Default: 1. - """ - - def __init__( - self, - num_classes=1, - depths=[3, 3, 9, 3], - dims=[96, 192, 384, 768], - include_top=True, - drop_path_rate=0.0, - layer_scale_init_value=1e-6, - head_init_scale=1.0, - ): - super().__init__() - self.include_top = include_top - self.downsample_layers = [] # stem and 3 intermediate downsampling conv layers - stem = tf.keras.Sequential( - [ - layers.Conv3D(dims[0], kernel_size=4, strides=4, padding="same"), - layers.LayerNormalization(epsilon=1e-6), - ] - ) - self.downsample_layers.append(stem) - for i in range(3): - downsample_layer = tf.keras.Sequential( - [ - layers.LayerNormalization(epsilon=1e-6), - layers.Conv3D( - dims[i + 1], kernel_size=2, strides=2, padding="same" - ), - ] - ) - self.downsample_layers.append(downsample_layer) - - self.stages = ( - [] - ) # 4 feature resolution stages, each consisting of multiple residual blocks - dp_rates = [x for x in np.linspace(0, drop_path_rate, sum(depths))] - cur = 0 - for i in range(4): - stage = tf.keras.Sequential( - [ - Block( - dim=dims[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value, - prefix=f"block{i}", - ) - for j in range(depths[i]) - ] - ) - self.stages.append(stage) - cur += depths[i] - - if self.include_top: - self.avg = layers.GlobalAveragePooling3D() - self.norm = layers.LayerNormalization(epsilon=1e-6) # final norm layer - self.head = layers.Dense(num_classes) - else: - self.avg = None - self.norm = None - self.head = None - - def forward_features(self, x): - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - return x - - def call(self, x): - x = self.forward_features(x) - if self.include_top: - x = self.avg(x) - x = self.norm(x) - x = self.head(x) - return x - - -model_configs = dict( - convnext_tiny=dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]), - convnext_small=dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768]), - convnext_base=dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]), - convnext_large=dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]), - convnext_xlarge=dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048]), -) - - -def create_model( - input_shape=(128, 128, 128, 1), - num_classes=1, - include_top=True, - model_name="convnext_tiny_1k", - **kwargs, -): - cfg = model_configs["_".join(model_name.split("_")[:2])] - net = ConvNeXt(num_classes, cfg["depths"], cfg["dims"], include_top, **kwargs) - net(tf.keras.Input(shape=input_shape)) - return net From 093daff551fc49b1ef697ac60bbb0e4bfe97485e Mon Sep 17 00:00:00 2001 From: Harsha Date: Fri, 22 Mar 2024 22:21:11 -0400 Subject: [PATCH 03/14] resolved https://github.com/neuronets/nobrainer/issues/308#issue-2203645742 --- nobrainer/processing/generation.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index bc0da104..2d4b93f8 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -5,7 +5,7 @@ from .base import BaseEstimator from .. import losses -from ..dataset import get_dataset +from ..dataset import Dataset class ProgressiveGeneration(BaseEstimator): @@ -136,8 +136,8 @@ def _compile(): d_loss_fn=d_loss, ) - print(self.model_.generator.summary()) - print(self.model_.discriminator.summary()) + self.model_.generator.summary() + self.model_.discriminator.summary() for resolution, info in dataset_train.items(): if resolution < self.current_resolution_: @@ -147,16 +147,16 @@ def _compile(): if batch_size % self.strategy.num_replicas_in_sync: raise ValueError("batch size must be a multiple of the number of GPUs") - dataset = get_dataset( + dataset = Dataset.from_tfrecords( file_pattern=info.get("file_pattern"), - batch_size=batch_size, num_parallel_calls=num_parallel_calls, volume_shape=(resolution, resolution, resolution), - n_classes=1, - scalar_label=True, - normalizer=info.get("normalizer") or normalizer, + scalar_labels=True, ) + if info.get("normalizer") or normalizer: + dataset = dataset.normalize(normalizer) + with self.strategy.scope(): # grow the networks by one (2^x) resolution if resolution > self.current_resolution_: @@ -164,9 +164,7 @@ def _compile(): self.model_.discriminator.add_resolution() _compile() - steps_per_epoch = (info.get("epochs") or epochs) // info.get( - "batch_size" - ) + steps_per_epoch = dataset.get_steps_per_epoch() # save_best_only is set to False as it is an adversarial loss model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( From 4991db681d1ab49373d86794354a82d2d3c6ec52 Mon Sep 17 00:00:00 2001 From: H Gazula Date: Sat, 23 Mar 2024 09:48:56 -0400 Subject: [PATCH 04/14] resolved https://github.com/neuronets/nobrainer/issues/310 --- nobrainer/processing/segmentation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 7ba9a32a..c6805443 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -38,10 +38,6 @@ def __init__( self.volume_shape_ = None self.scalar_labels_ = None - def add_model(self, base_model, model_args=None): - """Add a segmentation model""" - self.base_model = base_model - self.model_args = model_args or {} def fit( self, From 66edbb257b57dc287ca206ef9cf9c3ed5b485289 Mon Sep 17 00:00:00 2001 From: Harsha Date: Sat, 23 Mar 2024 11:32:59 -0400 Subject: [PATCH 05/14] resolved https://github.com/neuronets/nobrainer/issues/308#issuecomment-2016311834 --- nobrainer/processing/generation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index 2d4b93f8..da6a0edb 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -157,6 +157,10 @@ def _compile(): if info.get("normalizer") or normalizer: dataset = dataset.normalize(normalizer) + n_epochs = info.get("epochs") or epochs + dataset = dataset.repeat(n_epochs).batch(batch_size) + steps_per_epoch = dataset.get_steps_per_epoch() + with self.strategy.scope(): # grow the networks by one (2^x) resolution if resolution > self.current_resolution_: @@ -164,8 +168,6 @@ def _compile(): self.model_.discriminator.add_resolution() _compile() - steps_per_epoch = dataset.get_steps_per_epoch() - # save_best_only is set to False as it is an adversarial loss model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( str(model_dir), @@ -180,7 +182,7 @@ def _compile(): print("Transition phase") self.model_.fit( - dataset, + dataset.dataset, phase="transition", resolution=resolution, steps_per_epoch=steps_per_epoch, # necessary for repeat dataset @@ -189,7 +191,7 @@ def _compile(): print("Resolution phase") self.model_.fit( - dataset, + dataset.dataset, phase="resolution", resolution=resolution, steps_per_epoch=steps_per_epoch, From 690b68c755bcb3ecdcafed759ca2ea4688b33a58 Mon Sep 17 00:00:00 2001 From: Harsha Date: Tue, 2 Apr 2024 05:01:23 -0400 Subject: [PATCH 06/14] resolved https://github.com/neuronets/nobrainer/issues/314 also disable calculating scalar_labels, for more info https://github.com/neuronets/nobrainer/issues/313 --- nobrainer/processing/segmentation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index c6805443..4b57e469 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -38,7 +38,6 @@ def __init__( self.volume_shape_ = None self.scalar_labels_ = None - def fit( self, dataset_train, @@ -58,7 +57,7 @@ def fit( batch_size = dataset_train.batch_size self.block_shape_ = dataset_train.block_shape self.volume_shape_ = dataset_train.volume_shape - self.scalar_labels_ = dataset_train.scalar_labels + # self.scalar_labels_ = dataset_train.scalar_labels n_classes = dataset_train.n_classes opt_args = opt_args or {} if optimizer is None: @@ -100,6 +99,9 @@ def _compile(): if callbacks is None: callbacks = [] + dataset_train.repeat(epochs) + dataset_validate.repeat(epochs) + if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) self.model_.fit( From b85e6a953e5b5540c1410e84596b9ebda6161047 Mon Sep 17 00:00:00 2001 From: Harsha Date: Tue, 2 Apr 2024 18:24:19 -0400 Subject: [PATCH 07/14] Resolved https://github.com/neuronets/nobrainer/issues/315, https://github.com/neuronets/nobrainer/issues/316, https://github.com/neuronets/nobrainer/issues/317 --- nobrainer/dataset.py | 40 +++++++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 30c5be53..3b7b010d 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -69,6 +69,9 @@ def __init__( self.volume_shape = volume_shape self.n_classes = n_classes + if self.n_classes < 1: + raise ValueError("n_classes must be > 0.") + @classmethod def from_tfrecords( cls, @@ -80,6 +83,7 @@ def from_tfrecords( n_classes=1, tf_dataset_options=None, num_parallel_calls=1, + label_mapping=None, ): """Function to retrieve a saved tf record as a nobrainer Dataset @@ -123,6 +127,7 @@ def from_tfrecords( if not n_volumes: n_volumes = block_length * len(files) + print(f"n_volumes: {n_volumes}") dataset = dataset.interleave( map_func=lambda x: tf.data.TFRecordDataset( @@ -138,7 +143,9 @@ def from_tfrecords( if block_shape: ds_obj.block(block_shape) if not scalar_labels: - ds_obj.map_labels() + ds_obj.map_labels( + label_mapping=label_mapping, num_parallel_calls=num_parallel_calls + ) # TODO automatically determine batch size ds_obj.batch(1) @@ -158,6 +165,7 @@ def from_files( n_classes=1, block_shape=None, tf_dataset_options=None, + label_mapping=None, ): """Create Nobrainer datasets from data filepaths: List(str), list of paths to individual input data files. @@ -221,6 +229,7 @@ def from_files( block_shape=block_shape, tf_dataset_options=tf_dataset_options, num_parallel_calls=num_parallel_calls, + label_mapping=label_mapping, ) ds_eval = None if n_eval > 0: @@ -234,6 +243,7 @@ def from_files( block_shape=block_shape, tf_dataset_options=tf_dataset_options, num_parallel_calls=num_parallel_calls, + label_mapping=label_mapping, ) return ds_train, ds_eval @@ -315,19 +325,31 @@ def _f(x, y): self.dataset = self.dataset.unbatch() return self - def map_labels(self, label_mapping=None): - if self.n_classes < 1: - raise ValueError("n_classes must be > 0.") - + def map_labels(self, label_mapping=None, num_parallel_calls=1): if label_mapping is not None: - self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping))) + self.map(lambda x, y: (x, replace(y, mapping=label_mapping))) if self.n_classes == 1: - self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) + self.map( + lambda x, y: (x, tf.expand_dims(binarize(y), -1)), + num_parallel_calls=num_parallel_calls, + ) elif self.n_classes == 2: - self.map(lambda x, y: (x, tf.one_hot(binarize(y), self.n_classes))) + self.map( + lambda x, y: ( + x, + tf.one_hot(tf.cast(binarize(y), dtype=tf.int32), self.n_classes), + ), + num_parallel_calls=num_parallel_calls, + ) elif self.n_classes > 2: - self.map(lambda x, y: (x, tf.one_hot(y, self.n_classes))) + self.map( + lambda x, y: ( + x, + tf.one_hot(tf.cast(y, dtype=tf.int32), self.n_classes), + ), + num_parallel_calls=num_parallel_calls, + ) return self From 70b663be3bbb278815e598c9698603639e0ec25d Mon Sep 17 00:00:00 2001 From: Harsha Date: Thu, 25 Apr 2024 10:52:00 -0400 Subject: [PATCH 08/14] resolved https://github.com/neuronets/nobrainer/issues/329 --- nobrainer/tfrecord.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/nobrainer/tfrecord.py b/nobrainer/tfrecord.py index 405ba87d..0f3bbd95 100644 --- a/nobrainer/tfrecord.py +++ b/nobrainer/tfrecord.py @@ -26,7 +26,6 @@ def write( to_ras=True, compressed=True, processes=None, - chunksize=1, multi_resolution=False, resolutions=None, verbose=1, @@ -53,15 +52,13 @@ def write( writing to multiple TFRecord files (i.e., `examples_per_shard` < `len(features_labels)`). If `None`, uses all available cores. - chunksize: int, multiprocessing chunksize. multi_resolution: boolean, if `True`, different tfrecords for each resolution in each shard resolutions: list of ints, if multi_resolution is `True`, set resolutions for which tfrecords are created. For example, [4, 8, 16, 32, 64, 128, 256] verbose: int, if 1, print progress bar. If 0, print nothing. """ n_examples = len(features_labels) - n_shards = math.ceil(n_examples / examples_per_shard) - shards = np.array_split(features_labels, n_shards) + shards = np.array_split(features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard)) # Test that the `filename_template` has a `shard` formatting key. try: @@ -80,7 +77,7 @@ def write( # This is the object that returns a protocol buffer string of the feature and label # on each iteration. It is pickle-able, unlike a generator. proto_iterators = [ - _ProtoIterator(s, multi_resolution=multi_resolution, resolutions=resolutions) + _ProtoIterator(s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions) for s in shards ] # Set up positional arguments for the core writer function. @@ -90,14 +87,14 @@ def write( # Set keyword arguments so the resulting function accepts one positional argument. map_fn = functools.partial( _write_tfrecords, - compressed=True, + compressed=compressed, multi_resolution=multi_resolution, resolutions=resolutions, ) if processes is None: processes = get_num_parallel() - Parallel(n_jobs=processes, verbose=10)( + Parallel(n_jobs=processes, verbose=verbose)( delayed(__writer_func)(val, map_fn) for val in iterable ) from joblib.externals.loky import get_reusable_executor From d01e48a5e9e97ee989daa853fcfd3ee731faa1a3 Mon Sep 17 00:00:00 2001 From: Harsha Date: Sat, 11 May 2024 22:51:58 -0400 Subject: [PATCH 09/14] resolved https://github.com/neuronets/nobrainer/issues/334 --- nobrainer/processing/checkpoint.py | 1 + nobrainer/processing/segmentation.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/nobrainer/processing/checkpoint.py b/nobrainer/processing/checkpoint.py index bf54ef43..8ab71692 100644 --- a/nobrainer/processing/checkpoint.py +++ b/nobrainer/processing/checkpoint.py @@ -16,6 +16,7 @@ def __init__(self, estimator, file_path, **kwargs): file_path: str, directory to/from which to save or load. """ self.estimator = estimator + self.last_epoch = 0 super().__init__(file_path, **kwargs) def _save_model(self, epoch, batch, logs): diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 4b57e469..58d0e03e 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -50,6 +50,7 @@ def fit( metrics=metrics.dice, callbacks=None, verbose=1, + initial_epoch=0 ): """Train a segmentation model""" # TODO: check validity of datasets @@ -104,6 +105,7 @@ def _compile(): if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) + initial_epoch = self.checkpoint_tracker.last_epoch self.model_.fit( dataset_train.dataset, epochs=epochs, @@ -114,6 +116,7 @@ def _compile(): ), callbacks=callbacks, verbose=verbose, + initial_epoch=initial_epoch ) return self From a8e960b3b24680516ba0c3d01ed23f0863d59455 Mon Sep 17 00:00:00 2001 From: Harsha Date: Tue, 21 May 2024 11:27:45 -0400 Subject: [PATCH 10/14] add bayesian vnet to list of available models --- nobrainer/models/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py index a4842bc7..68a797d3 100644 --- a/nobrainer/models/__init__.py +++ b/nobrainer/models/__init__.py @@ -4,6 +4,7 @@ from .attention_unet_with_inception import attention_unet_with_inception from .autoencoder import autoencoder from .bayesian_meshnet import variational_meshnet +from .bayesian_vnet import bayesian_vnet from .dcgan import dcgan from .highresnet import highresnet from .meshnet import meshnet @@ -26,6 +27,7 @@ "attention_unet_with_inception": attention_unet_with_inception, "unetr": unetr, "variational_meshnet": variational_meshnet, + "bayesian_vnet": bayesian_vnet } From f0396cbb3dfeaa1dcab356f92dfa0f0999e334bc Mon Sep 17 00:00:00 2001 From: Harsha Date: Mon, 10 Jun 2024 09:01:19 -0400 Subject: [PATCH 11/14] resolved https://github.com/neuronets/nobrainer/issues/336 --- nobrainer/dataset.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 3b7b010d..f2d6e8bb 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -146,6 +146,8 @@ def from_tfrecords( ds_obj.map_labels( label_mapping=label_mapping, num_parallel_calls=num_parallel_calls ) + + ds_obj.filter_zero_volumes() # TODO automatically determine batch size ds_obj.batch(1) @@ -385,3 +387,9 @@ def repeat(self, n_repeats): # through once. self.dataset = self.dataset.repeat(n_repeats) return self + + def filter_zero_volumes(self): + self.dataset = self.dataset.filter( + lambda x, y: tf.cast(tf.math.reduce_sum(y), dtype="bool") + ) + return self From f9def401a627979b6a13c29b17d088a6e106e1dd Mon Sep 17 00:00:00 2001 From: Harsha Date: Wed, 19 Jun 2024 12:39:33 -0400 Subject: [PATCH 12/14] success: generation --- nobrainer/ext/SynthSeg/__init__.py | 1 + nobrainer/ext/SynthSeg/model_inputs.py | 163 ++ nobrainer/ext/__init__.py | 0 nobrainer/ext/lab2im/__init__.py | 6 + nobrainer/ext/lab2im/edit_tensors.py | 346 +++ nobrainer/ext/lab2im/edit_volumes.py | 2836 +++++++++++++++++++++ nobrainer/ext/lab2im/image_generator.py | 266 ++ nobrainer/ext/lab2im/lab2im_model.py | 174 ++ nobrainer/ext/lab2im/layers.py | 2060 +++++++++++++++ nobrainer/ext/lab2im/utils.py | 1057 ++++++++ nobrainer/ext/neuron/__init__.py | 3 + nobrainer/ext/neuron/layers.py | 435 ++++ nobrainer/ext/neuron/models.py | 768 ++++++ nobrainer/ext/neuron/utils.py | 548 ++++ nobrainer/models/__init__.py | 4 +- nobrainer/models/labels_to_image_model.py | 297 +++ nobrainer/processing/brain_generator.py | 335 +++ 17 files changed, 9298 insertions(+), 1 deletion(-) create mode 100644 nobrainer/ext/SynthSeg/__init__.py create mode 100644 nobrainer/ext/SynthSeg/model_inputs.py create mode 100644 nobrainer/ext/__init__.py create mode 100644 nobrainer/ext/lab2im/__init__.py create mode 100644 nobrainer/ext/lab2im/edit_tensors.py create mode 100644 nobrainer/ext/lab2im/edit_volumes.py create mode 100644 nobrainer/ext/lab2im/image_generator.py create mode 100644 nobrainer/ext/lab2im/lab2im_model.py create mode 100644 nobrainer/ext/lab2im/layers.py create mode 100644 nobrainer/ext/lab2im/utils.py create mode 100644 nobrainer/ext/neuron/__init__.py create mode 100644 nobrainer/ext/neuron/layers.py create mode 100644 nobrainer/ext/neuron/models.py create mode 100644 nobrainer/ext/neuron/utils.py create mode 100644 nobrainer/models/labels_to_image_model.py create mode 100644 nobrainer/processing/brain_generator.py diff --git a/nobrainer/ext/SynthSeg/__init__.py b/nobrainer/ext/SynthSeg/__init__.py new file mode 100644 index 00000000..151b2e92 --- /dev/null +++ b/nobrainer/ext/SynthSeg/__init__.py @@ -0,0 +1 @@ +from . import model_inputs diff --git a/nobrainer/ext/SynthSeg/model_inputs.py b/nobrainer/ext/SynthSeg/model_inputs.py new file mode 100644 index 00000000..34f4b4ad --- /dev/null +++ b/nobrainer/ext/SynthSeg/model_inputs.py @@ -0,0 +1,163 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np +import numpy.random as npr + +# third-party imports +from nobrainer.ext.lab2im import utils + + +def build_model_inputs(path_label_maps, + n_labels, + batchsize=1, + n_channels=1, + subjects_prob=None, + generation_classes=None, + prior_distributions='uniform', + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False): + """ + This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the + input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch). + :param path_label_maps: list of the paths of the input label maps. + :param n_labels: number of labels in the input label maps. + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthesised. Default is 1. + :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to pick + the provided label maps at each minibatch. Must be a 1D numpy array, as long as path_label_maps. + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence or a + 1d numpy array. It should have the same length as generation_labels, and contain values between 0 and K-1, where K + is the total number of classes. Default is all labels have different classes. + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each mini-batch + from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, or from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default + values for half of these cases, and thus generate images of random contrast. + """ + + # allocate unique class to each label if generation classes is not given + if generation_classes is None: + generation_classes = np.arange(n_labels) + n_classes = len(np.unique(generation_classes)) + + # make sure subjects_prob sums to 1 + subjects_prob = utils.load_array_if_path(subjects_prob) + if subjects_prob is not None: + subjects_prob /= np.sum(subjects_prob) + + # Generate! + while True: + + # randomly pick as many images as batchsize + indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob) + + # initialise input lists + list_label_maps = [] + list_means = [] + list_stds = [] + + for idx in indices: + + # load input label map + lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4)) + if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]): + lab[lab == 24] = 0 + + # add label map to inputs + list_label_maps.append(utils.add_axis(lab, axis=[0, -1])) + + # add means and standard deviations to inputs + means = np.empty((1, n_labels, 0)) + stds = np.empty((1, n_labels, 0)) + for channel in range(n_channels): + + # retrieve channel specific stats if necessary + if isinstance(prior_means, np.ndarray): + if (prior_means.shape[0] > 2) & use_specific_stats_for_channel: + if prior_means.shape[0] / 2 != n_channels: + raise ValueError("the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True.") + tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :] + else: + tmp_prior_means = prior_means + else: + tmp_prior_means = prior_means + if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5): + tmp_prior_means = None + if isinstance(prior_stds, np.ndarray): + if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel: + if prior_stds.shape[0] / 2 != n_channels: + raise ValueError("the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True.") + tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :] + else: + tmp_prior_stds = prior_stds + else: + tmp_prior_stds = prior_stds + if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5): + tmp_prior_stds = None + + # draw means and std devs from priors + tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions, + 125., 125., positive_only=True) + tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions, + 15., 15., positive_only=True) + random_coef = npr.uniform() + if random_coef > 0.95: # reset the background to 0 in 5% of cases + tmp_classes_means[0] = 0 + tmp_classes_stds[0] = 0 + elif random_coef > 0.7: # reset the background to low Gaussian in 25% of cases + tmp_classes_means[0] = npr.uniform(0, 15) + tmp_classes_stds[0] = npr.uniform(0, 5) + tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1]) + tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1]) + means = np.concatenate([means, tmp_means], axis=-1) + stds = np.concatenate([stds, tmp_stds], axis=-1) + list_means.append(means) + list_stds.append(stds) + + # build list of inputs for generation model + list_inputs = [list_label_maps, list_means, list_stds] + if batchsize > 1: # concatenate each input type if batchsize > 1 + list_inputs = [np.concatenate(item, 0) for item in list_inputs] + else: + list_inputs = [item[0] for item in list_inputs] + + yield list_inputs diff --git a/nobrainer/ext/__init__.py b/nobrainer/ext/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nobrainer/ext/lab2im/__init__.py b/nobrainer/ext/lab2im/__init__.py new file mode 100644 index 00000000..d3fd52d0 --- /dev/null +++ b/nobrainer/ext/lab2im/__init__.py @@ -0,0 +1,6 @@ +from . import edit_tensors +from . import edit_volumes +from . import image_generator +from . import lab2im_model +from . import layers +from . import utils diff --git a/nobrainer/ext/lab2im/edit_tensors.py b/nobrainer/ext/lab2im/edit_tensors.py new file mode 100644 index 00000000..65e72a02 --- /dev/null +++ b/nobrainer/ext/lab2im/edit_tensors.py @@ -0,0 +1,346 @@ +""" + +This file contains functions to handle keras/tensorflow tensors. + - blurring_sigma_for_downsampling + - gaussian_kernel + - resample_tensor + - expand_dims + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. + +""" + + +# python imports +import numpy as np +import tensorflow as tf +import keras.layers as KL +import keras.backend as K +from itertools import combinations + +# project imports +from nobrainer.ext.lab2im import utils + +# third-party imports +import nobrainer.ext.neuron.layers as nrn_layers +from nobrainer.ext.neuron.utils import volshape_to_meshgrid + + +def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None): + """Compute standard deviations of 1d gaussian masks for image blurring before downsampling. + :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor. + :param current_res: resolution of the volume before downsampling. + Can be a 1d numpy array or list or tensor of the same length as downsample res. + :param mult_coef: (optional) multiplicative coefficient for the blurring kernel. Default is 0.75. + :param thickness: (optional) slice thickness in each dimension. Must be the same type as downsample_res. + :return: standard deviation of the blurring masks given as the same type as downsample_res (list or tensor). + """ + + if not tf.is_tensor(downsample_res): + + # get blurring resolution (min between downsample_res and thickness) + current_res = np.array(current_res) + downsample_res = np.array(downsample_res) + if thickness is not None: + downsample_res = np.minimum(downsample_res, np.array(thickness)) + + # get std deviation for blurring kernels + if mult_coef is None: + sigma = 0.75 * downsample_res / current_res + sigma[downsample_res == current_res] = 0.5 + else: + sigma = mult_coef * downsample_res / current_res + sigma[downsample_res == 0] = 0 + + else: + + # reformat data resolution at which we blur + if thickness is not None: + down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness]) + else: + down_res = downsample_res + + # get std deviation for blurring kernels + if mult_coef is None: + sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x, tf.convert_to_tensor(current_res, dtype='float32')), + 0.5, 0.75 * x / tf.convert_to_tensor(current_res, dtype='float32')))(down_res) + else: + sigma = KL.Lambda(lambda x: mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32'))(down_res) + sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([down_res, sigma]) + + return sigma + + +def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): + """Build gaussian kernels of the specified standard deviation. The outputs are given as tensorflow tensors. + :param sigma: standard deviation of the tensors. Can be given as a list/numpy array or as tensors. In each case, + sigma must have the same length as the number of dimensions of the volume that will be blurred with the output + tensors (e.g. sigma must have 3 values for 3D volumes). + :param max_sigma: + :param blur_range: + :param separable: + :return: + """ + # convert sigma into a tensor + if not tf.is_tensor(sigma): + sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32') + else: + assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor' + sigma_tens = sigma + shape = sigma_tens.get_shape().as_list() + + # get n_dims and batchsize + if shape[0] is not None: + n_dims = shape[0] + batchsize = None + else: + n_dims = shape[1] + batchsize = tf.split(tf.shape(sigma_tens), [1, -1])[0] + + # reformat max_sigma + if max_sigma is not None: # dynamic blurring + max_sigma = np.array(utils.reformat_to_list(max_sigma, length=n_dims)) + else: # sigma is fixed + max_sigma = np.array(utils.reformat_to_list(sigma, length=n_dims)) + + # randomise the burring std dev and/or split it between dimensions + if blur_range is not None: + if blur_range != 1: + sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range) + + # get size of blurring kernels + windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1 + + if separable: + + split_sigma = tf.split(sigma_tens, [1] * n_dims, axis=-1) + + kernels = list() + comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) + for (i, wsize) in enumerate(windowsize): + + if wsize > 1: + + # build meshgrid and replicate it along batch dim if dynamic blurring + locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2 + if batchsize is not None: + locations = tf.tile(tf.expand_dims(locations, axis=0), + tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')], + axis=0)) + comb[i] += 1 + + # compute gaussians + exp_term = -K.square(locations) / (2 * split_sigma[i] ** 2) + g = tf.exp(exp_term - tf.math.log(np.sqrt(2 * np.pi) * split_sigma[i])) + g = g / tf.reduce_sum(g) + + for axis in comb[i]: + g = tf.expand_dims(g, axis=axis) + kernels.append(tf.expand_dims(tf.expand_dims(g, -1), -1)) + + else: + kernels.append(None) + + else: + + # build meshgrid + mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')] + diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1) + + # replicate meshgrid to batch size and reshape sigma_tens + if batchsize is not None: + diff = tf.tile(tf.expand_dims(diff, axis=0), + tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0)) + for i in range(n_dims): + sigma_tens = tf.expand_dims(sigma_tens, axis=1) + else: + for i in range(n_dims): + sigma_tens = tf.expand_dims(sigma_tens, axis=0) + + # compute gaussians + sigma_is_0 = tf.equal(sigma_tens, 0) + exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2) + norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens)) + kernels = K.sum(norms, -1) + kernels = tf.exp(kernels) + kernels /= tf.reduce_sum(kernels) + kernels = tf.expand_dims(tf.expand_dims(kernels, -1), -1) + + return kernels + + +def sobel_kernels(n_dims): + """Returns sobel kernels to compute spatial derivative on image of n dimensions.""" + + in_dir = tf.convert_to_tensor([1, 0, -1], dtype='float32') + orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype='float32') + comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) + + list_kernels = list() + for dim in range(n_dims): + + sublist_kernels = list() + for axis in range(n_dims): + + kernel = in_dir if axis == dim else orthogonal_dir + for i in comb[axis]: + kernel = tf.expand_dims(kernel, axis=i) + sublist_kernels.append(tf.expand_dims(tf.expand_dims(kernel, -1), -1)) + + list_kernels.append(sublist_kernels) + + return list_kernels + + +def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None): + """Build kernel with values of 1 for voxel at a distance < dist_threshold from the center, and 0 otherwise. + The outputs are given as tensorflow tensors. + :param dist_threshold: maximum distance from the center until voxel will have a value of 1. Can be a tensor of size + (batch_size, 1), or a float. + :param n_dims: dimension of the kernel to return (excluding batch and channel dimensions). + :param max_dist_threshold: if distance_threshold is a tensor, max_dist_threshold must be given. It represents the + maximum value that will be passed to dist_threshold. Must be a float. + """ + + # convert dist_threshold into a tensor + if not tf.is_tensor(dist_threshold): + dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32') + else: + assert max_dist_threshold is not None, 'max_sigma must be provided when dist_threshold is given as a tensor' + dist_threshold_tens = tf.cast(dist_threshold, 'float32') + shape = dist_threshold_tens.get_shape().as_list() + + # get batchsize + batchsize = None if shape[0] is not None else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0] + + # set max_dist_threshold into an array + if max_dist_threshold is None: # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch) + max_dist_threshold = dist_threshold + + # get size of blurring kernels + windowsize = np.array([max_dist_threshold * 2 + 1]*n_dims, dtype='int32') + + # build tensor representing the distance from the centre + mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')] + dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1) + dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1)) + + # replicate distance to batch size and reshape sigma_tens + if batchsize is not None: + dist = tf.tile(tf.expand_dims(dist, axis=0), + tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0)) + for i in range(n_dims - 1): + dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1) + else: + for i in range(n_dims - 1): + dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0) + + # build final kernel by thresholding distance tensor + kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist)) + kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) + + return kernel + + +def resample_tensor(tensor, + resample_shape, + interp_method='linear', + subsample_res=None, + volume_res=None, + build_reliability_map=False): + """This function resamples a volume to resample_shape. It does not apply any pre-filtering. + A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be + specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which + slices were interpolated during resampling from the downsampled to final tensor. + :param tensor: tensor + :param resample_shape: list or numpy array of size (n_dims,) + :param interp_method: (optional) interpolation method for resampling, 'linear' (default) or 'nearest' + :param subsample_res: (optional) if not None, this triggers a downsampling of the volume, prior to the resampling + step. List or numpy array of size (n_dims,). Default si None. + :param volume_res: (optional) if subsample_res is not None, this should be provided to compute downsampling ratio. + list or numpy array of size (n_dims,). Default is None. + :param build_reliability_map: whether to return reliability map along with the resampled tensor. This map indicates + which slices of the resampled tensor are interpolated (0=interpolated, 1=real slice, in between=degree of realness). + :return: resampled volume, with reliability map if necessary. + """ + + # reformat resolutions to lists + subsample_res = utils.reformat_to_list(subsample_res) + volume_res = utils.reformat_to_list(volume_res) + n_dims = len(resample_shape) + + # downsample image + tensor_shape = tensor.get_shape().as_list()[1:-1] + downsample_shape = tensor_shape # will be modified if we actually downsample + + if subsample_res is not None: + assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.' + assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \ + 'had {0}, and {1}'.format(len(subsample_res), len(volume_res)) + if subsample_res != volume_res: + + # get shape at which we downsample + downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)] + + # downsample volume + tensor._keras_shape = tuple(tensor.get_shape().as_list()) + tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor) + + # resample image at target resolution + if resample_shape != downsample_shape: # if we didn't downsample downsample_shape = tensor_shape + tensor._keras_shape = tuple(tensor.get_shape().as_list()) + tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor) + + # compute reliability maps if necessary and return results + if build_reliability_map: + + # compute maps only if we downsampled + if downsample_shape != tensor_shape: + + # compute upsampling factors + upsampling_factors = np.array(resample_shape) / np.array(downsample_shape) + + # build reliability map + reliability_map = 1 + for i in range(n_dims): + loc_float = np.arange(0, resample_shape[i], upsampling_factors[i]) + loc_floor = np.int32(np.floor(loc_float)) + loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1)) + tmp_reliability_map = np.zeros(resample_shape[i]) + tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor) + tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor) + shape = [1, 1, 1] + shape[i] = resample_shape[i] + reliability_map = reliability_map * np.reshape(tmp_reliability_map, shape) + shape = KL.Lambda(lambda x: tf.shape(x))(tensor) + mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'), + shape=x))(shape) + + # otherwise just return an all-one tensor + else: + mask = KL.Lambda(lambda x: tf.ones_like(x))(tensor) + + return tensor, mask + + else: + return tensor + + +def expand_dims(tensor, axis=0): + """Expand the dimensions of the input tensor along the provided axes (given as an integer or a list).""" + axis = utils.reformat_to_list(axis) + for ax in axis: + tensor = tf.expand_dims(tensor, axis=ax) + return tensor diff --git a/nobrainer/ext/lab2im/edit_volumes.py b/nobrainer/ext/lab2im/edit_volumes.py new file mode 100644 index 00000000..c8e388a5 --- /dev/null +++ b/nobrainer/ext/lab2im/edit_volumes.py @@ -0,0 +1,2836 @@ +""" +This file contains functions to edit/preprocess volumes (i.e. not tensors!). +These functions are sorted in five categories: +1- volume editing: this can be applied to any volume (i.e. images or label maps). It contains: + -mask_volume + -rescale_volume + -crop_volume + -crop_volume_around_region + -crop_volume_with_idx + -pad_volume + -flip_volume + -resample_volume + -resample_volume_like + -get_ras_axes + -align_volume_to_ref + -blur_volume +2- label map editing: can be applied to label maps only. It contains: + -correct_label_map + -mask_label_map + -smooth_label_map + -erode_label_map + -get_largest_connected_component + -compute_hard_volumes + -compute_distance_map +3- editing all volumes in a folder: functions are more or less the same as 1, but they now apply to all the volumes +in a given folder. Thus we provide folder paths rather than numpy arrays as inputs. It contains: + -mask_images_in_dir + -rescale_images_in_dir + -crop_images_in_dir + -crop_images_around_region_in_dir + -pad_images_in_dir + -flip_images_in_dir + -align_images_in_dir + -correct_nans_images_in_dir + -blur_images_in_dir + -create_mutlimodal_images + -convert_images_in_dir_to_nifty + -mri_convert_images_in_dir + -samseg_images_in_dir + -niftyreg_images_in_dir + -upsample_anisotropic_images + -simulate_upsampled_anisotropic_images + -check_images_in_dir +4- label maps in dir: same as 3 but for label map-specific functions. It contains: + -correct_labels_in_dir + -mask_labels_in_dir + -smooth_labels_in_dir + -erode_labels_in_dir + -upsample_labels_in_dir + -compute_hard_volumes_in_dir + -build_atlas +5- dataset editing: functions for editing datasets (i.e. images with corresponding label maps). It contains: + -check_images_and_labels + -crop_dataset_to_minimum_size + -subdivide_dataset_to_patches + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import os +import csv +import shutil +import numpy as np +import tensorflow as tf +import keras.layers as KL +from keras.models import Model +from scipy.ndimage.filters import convolve +from scipy.ndimage import label as scipy_label +from scipy.interpolate import RegularGridInterpolator +from scipy.ndimage.morphology import distance_transform_edt, binary_fill_holes +from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter + +# project imports +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im.layers import GaussianBlur, ConvertLabels +from nobrainer.ext.lab2im.edit_tensors import blurring_sigma_for_downsampling + + +# ---------------------------------------------------- edit volume ----------------------------------------------------- + +def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0, + return_mask=False, return_copy=True): + """Mask a volume, either with a given mask, or by keeping only the values above a threshold. + :param volume: a numpy array, possibly with several channels + :param mask: (optional) a numpy array to mask volume with. + Mask doesn't have to be a 0/1 array, all strictly positive values of mask are considered for masking volume. + Mask should have the same size as volume. If volume has several channels, mask can either be uni- or multi-channel. + In the first case, the same mask is applied to all channels. + :param threshold: (optional) If mask is None, masking is performed by keeping thresholding the input. + :param dilate: (optional) number of voxels by which to dilate the provided or computed mask. + :param erode: (optional) number of voxels by which to erode the provided or computed mask. + :param fill_holes: (optional) whether to fill the holes in the provided or computed mask. + :param masking_value: (optional) masking value + :param return_mask: (optional) whether to return the applied mask + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: the masked volume, and the applied mask if return_mask is True. + """ + + # get info + new_volume = volume.copy() if return_copy else volume + vol_shape = list(new_volume.shape) + n_dims, n_channels = utils.get_dims(vol_shape) + + # get mask and erode/dilate it + if mask is None: + mask = new_volume >= threshold + else: + assert list(mask.shape[:n_dims]) == vol_shape[:n_dims], 'mask should have shape {0}, or {1}, had {2}'.format( + vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape)) + mask = mask > 0 + if dilate > 0: + dilate_struct = utils.build_binary_structure(dilate, n_dims) + mask_to_apply = binary_dilation(mask, dilate_struct) + else: + mask_to_apply = mask + if erode > 0: + erode_struct = utils.build_binary_structure(erode, n_dims) + mask_to_apply = binary_erosion(mask_to_apply, erode_struct) + if fill_holes: + mask_to_apply = binary_fill_holes(mask_to_apply) + + # replace values outside of mask by padding_char + if mask_to_apply.shape == new_volume.shape: + new_volume[np.logical_not(mask_to_apply)] = masking_value + else: + new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = masking_value + + if return_mask: + return new_volume, mask_to_apply + else: + return new_volume + + +def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2, max_percentile=98, use_positive_only=False): + """This function linearly rescales a volume between new_min and new_max. + :param volume: a numpy array + :param new_min: (optional) minimum value for the rescaled image. + :param new_max: (optional) maximum value for the rescaled image. + :param min_percentile: (optional) percentile for estimating robust minimum of volume (float in [0,...100]), + where 0 = np.min + :param max_percentile: (optional) percentile for estimating robust maximum of volume (float in [0,...100]), + where 100 = np.max + :param use_positive_only: (optional) whether to use only positive values when estimating the min and max percentile + :return: rescaled volume + """ + + # select only positive intensities + new_volume = volume.copy() + intensities = new_volume[new_volume > 0] if use_positive_only else new_volume.flatten() + + # define min and max intensities in original image for normalisation + robust_min = np.min(intensities) if min_percentile == 0 else np.percentile(intensities, min_percentile) + robust_max = np.max(intensities) if max_percentile == 100 else np.percentile(intensities, max_percentile) + + # trim values outside range + new_volume = np.clip(new_volume, robust_min, robust_max) + + # rescale image + if robust_min != robust_max: + return new_min + (new_volume - robust_min) / (robust_max - robust_min) * (new_max - new_min) + else: # avoid dividing by zero + return np.zeros_like(new_volume) + + +def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'): + """Crop volume by a given margin, or to a given shape. + :param volume: 2d or 3d numpy array (possibly with multiple channels) + :param cropping_margin: (optional) margin by which to crop the volume. The cropping margin is applied on both sides. + Can be an int, sequence or 1d numpy array of size n_dims. Should be given if cropping_shape is None. + :param cropping_shape: (optional) shape to which the volume will be cropped. Can be an int, sequence or 1d numpy + array of size n_dims. Should be given if cropping_margin is None. + :param aff: (optional) affine matrix of the input volume. + If not None, this function also returns an updated version of the affine matrix for the cropped volume. + :param return_crop_idx: (optional) whether to return the cropping indices used to crop the given volume. + :param mode: (optional) if cropping_shape is not None, whether to extract the centre of the image (mode='center'), + or to randomly crop the volume to the provided shape (mode='random'). Default is 'center'. + :return: cropped volume, corresponding affine matrix if aff is not None, and cropping indices if return_crop_idx is + True (in that order). + """ + + assert (cropping_margin is not None) | (cropping_shape is not None), \ + 'cropping_margin or cropping_shape should be provided' + assert not ((cropping_margin is not None) & (cropping_shape is not None)), \ + 'only one of cropping_margin or cropping_shape should be provided' + + # get info + new_volume = volume.copy() + vol_shape = new_volume.shape + n_dims, _ = utils.get_dims(vol_shape) + + # find cropping indices + if cropping_margin is not None: + cropping_margin = utils.reformat_to_list(cropping_margin, length=n_dims) + do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin) + min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)] + max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)] + else: + cropping_shape = utils.reformat_to_list(cropping_shape, length=n_dims) + if mode == 'center': + min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0) + max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)], + np.array(vol_shape)[:n_dims]) + elif mode == 'random': + crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0) + min_crop_idx = np.random.randint(0, high=crop_max_val + 1) + max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims]) + else: + raise ValueError('mode should be either "center" or "random", had %s' % mode) + crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)]) + + # crop volume + if n_dims == 2: + new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...] + elif n_dims == 3: + new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...] + + # sort outputs + output = [new_volume] + if aff is not None: + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ np.array(min_crop_idx) + output.append(aff) + if return_crop_idx: + output.append(crop_idx) + return output[0] if len(output) == 1 else tuple(output) + + +def crop_volume_around_region(volume, + mask=None, + masking_labels=None, + threshold=0.1, + margin=0, + cropping_shape=None, + cropping_shape_div_by=None, + aff=None, + overflow='strict'): + """Crop a volume around a specific region. + This region is defined by a mask obtained by either: + 1) directly specifying it as input (see mask) + 2) keeping a set of label values if the volume is a label map (see masking_labels). + 3) thresholding the input volume (see threshold) + The cropping region is defined by the bounding box of the mask, which we can further modify by either: + 1) extending it by a margin (see margin) + 2) providing a specific cropping shape, in this case the cropping region will be centered around the bounding box + (see cropping_shape). + 3) extending it to a shape that is divisible by a given number. Again, the cropping region will be centered around + the bounding box (see cropping_shape_div_by). + Finally, if the size of the cropping region has been modified, and that this modified size overflows out of the + image (e.g. because the center of the mask is close to the edge), we can either: + 1) stick to the valid image space (the size of the modified cropping region won't be respected) + 2) shift the cropping region so that it lies on the valid image space, and if it still overflows, then we restrict + to the valid image space. + 3) pad the image with zeros, such that the cropping region is not ill-defined anymore. + 3) shift the cropping region to the valida image space, and if it still overflows, then we pad with zeros. + :param volume: a 2d or 3d numpy array + :param mask: (optional) mask of region to crop around. Must be same size as volume. Can either be boolean or 0/1. + If no mask is given, it will be computed by either thresholding the input volume or using masking_labels. + :param masking_labels: (optional) if mask is None, and if the volume is a label map, it can be cropped around a + set of labels specified in masking_labels, which can either be a single int, a sequence or a 1d numpy array. + :param threshold: (optional) if mask amd masking_labels are None, lower bound to determine values to crop around. + :param margin: (optional) add margin around mask + :param cropping_shape: (optional) shape to which the input volumes must be cropped. Volumes are padded around the + centre of the above-defined mask is they are too small for the given shape. Can be an integer or sequence. + Cannot be given at the same time as margin or cropping_shape_div_by. + :param cropping_shape_div_by: (optional) makes sure the shape of the cropped region is divisible by the provided + number. If it is not, then we enlarge the cropping area. If the enlarged area is too big for the input volume, we + pad it with 0. Must be an integer. Cannot be given at the same time as margin or cropping_shape. + :param aff: (optional) if specified, this function returns an updated affine matrix of the volume after cropping. + :param overflow: (optional) how to proceed when the cropping region overflows outside the initial image space. + Can either be 'strict' (default), 'shift-strict', 'padding', 'shift-padding. + :return: the cropped volume, the cropping indices (in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...]), + and the updated affine matrix if aff is not None. + """ + + assert not ((margin > 0) & (cropping_shape is not None)), "margin and cropping_shape can't be given together." + assert not ((margin > 0) & (cropping_shape_div_by is not None)), \ + "margin and cropping_shape_div_by can't be given together." + assert not ((cropping_shape_div_by is not None) & (cropping_shape is not None)), \ + "cropping_shape_div_by and cropping_shape can't be given together." + + new_vol = volume.copy() + n_dims, n_channels = utils.get_dims(new_vol.shape) + vol_shape = np.array(new_vol.shape[:n_dims]) + + # mask ROIs for cropping + if mask is None: + if masking_labels is not None: + _, mask = mask_label_map(new_vol, masking_values=masking_labels, return_mask=True) + else: + mask = new_vol > threshold + + # find cropping indices + if np.any(mask): + + indices = np.nonzero(mask) + min_idx = np.array([np.min(idx) for idx in indices]) + max_idx = np.array([np.max(idx) for idx in indices]) + intermediate_vol_shape = max_idx - min_idx + + if (margin == 0) & (cropping_shape is None) & (cropping_shape_div_by is None): + cropping_shape = intermediate_vol_shape + if margin: + cropping_shape = intermediate_vol_shape + 2 * margin + elif cropping_shape is not None: + cropping_shape = np.array(utils.reformat_to_list(cropping_shape, length=n_dims)) + elif cropping_shape_div_by is not None: + cropping_shape = [utils.find_closest_number_divisible_by_m(s, cropping_shape_div_by, answer_type='higher') + for s in intermediate_vol_shape] + + min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) + max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + min_overflow = np.abs(np.minimum(min_idx, 0)) + max_overflow = np.maximum(max_idx - vol_shape, 0) + + if 'strict' in overflow: + min_overflow = np.zeros_like(min_overflow) + max_overflow = np.zeros_like(min_overflow) + + if overflow == 'shift-strict': + min_idx -= max_overflow + max_idx += min_overflow + + if overflow == 'shift-padding': + for ii in range(n_dims): + # no need to do anything if both min/max_overflow are 0 (no padding/shifting required at all) + # or if both are positive, because in this case we don't shift at all and we pad directly + if (min_overflow[ii] > 0) & (max_overflow[ii] == 0): + max_idx_new = max_idx[ii] + min_overflow[ii] + if max_idx_new <= vol_shape[ii]: + max_idx[ii] = max_idx_new + min_overflow[ii] = 0 + else: + min_overflow[ii] = min_overflow[ii] - (vol_shape[ii] - max_idx[ii]) + max_idx[ii] = vol_shape[ii] + elif (min_overflow[ii] == 0) & (max_overflow[ii] > 0): + min_idx_new = min_idx[ii] - max_overflow[ii] + if min_idx_new >= 0: + min_idx[ii] = min_idx_new + max_overflow[ii] = 0 + else: + max_overflow[ii] = max_overflow[ii] - min_idx[ii] + min_idx[ii] = 0 + + # crop volume if necessary + min_idx = np.maximum(min_idx, 0) + max_idx = np.minimum(max_idx, vol_shape) + cropping = np.concatenate([min_idx, max_idx]) + if np.any(cropping[:3] > 0) or np.any(cropping[3:] != vol_shape): + if n_dims == 3: + new_vol = new_vol[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] + elif n_dims == 2: + new_vol = new_vol[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] + else: + raise ValueError('cannot crop volumes with more than 3 dimensions') + + # pad volume if necessary + if np.any(min_overflow > 0) | np.any(max_overflow > 0): + pad_margins = tuple([(min_overflow[i], max_overflow[i]) for i in range(n_dims)]) + pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + new_vol = np.pad(new_vol, pad_margins, mode='constant', constant_values=0) + + # if there's nothing to crop around, we return the input as is + else: + min_idx = min_overflow = np.zeros(3) + cropping = None + + # return results + if aff is not None: + if n_dims == 2: + min_idx = np.append(min_idx, 0) + min_overflow = np.append(min_overflow, 0) + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ min_idx + aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_overflow + return new_vol, cropping, aff + else: + return new_vol, cropping + + +def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True): + """Crop a volume with given indices. + :param volume: a 2d or 3d numpy array + :param crop_idx: cropping indices, in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...]. + Can be a list or a 1d numpy array. + :param aff: (optional) if aff is specified, this function returns an updated affine matrix of the volume after + cropping. + :param n_dims: (optional) number of dimensions (excluding channels) of the volume. If not provided, n_dims will be + inferred from the input volume. + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: the cropped volume, and the updated affine matrix if aff is not None. + """ + + # get info + new_volume = volume.copy() if return_copy else volume + n_dims = int(np.array(crop_idx).shape[0] / 2) if n_dims is None else n_dims + + # crop image + if n_dims == 2: + new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...] + elif n_dims == 3: + new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...] + else: + raise Exception('cannot crop volumes with more than 3 dimensions') + + if aff is not None: + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ crop_idx[:3] + return new_volume, aff + else: + return new_volume + + +def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx=False): + """Pad volume to a given shape + :param volume: volume to be padded + :param padding_shape: shape to pad volume to. Can be a number, a sequence or a 1d numpy array. + :param padding_value: (optional) value used for padding + :param aff: (optional) affine matrix of the volume + :param return_pad_idx: (optional) the pad_idx corresponds to the indices where we should crop the resulting + padded image (ie the output of this function) to go back to the original volume (ie the input of this function). + :return: padded volume, and updated affine matrix if aff is not None. + """ + + # get info + new_volume = volume.copy() + vol_shape = new_volume.shape + n_dims, n_channels = utils.get_dims(vol_shape) + padding_shape = utils.reformat_to_list(padding_shape, length=n_dims, dtype='int') + + # check if need to pad + if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')): + + # get padding margins + min_margins = np.maximum(np.int32(np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0) + max_margins = np.maximum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0) + pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])]) + pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) + if n_channels > 1: + pad_margins = tuple(list(pad_margins) + [(0, 0)]) + + # pad volume + new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value) + + if aff is not None: + if n_dims == 2: + min_margins = np.append(min_margins, 0) + aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_margins + + else: + pad_idx = np.concatenate([np.array([0] * n_dims), np.array(vol_shape[:n_dims])]) + + # sort outputs + output = [new_volume] + if aff is not None: + output.append(aff) + if return_pad_idx: + output.append(pad_idx) + return output[0] if len(output) == 1 else tuple(output) + + +def flip_volume(volume, axis=None, direction=None, aff=None, return_copy=True): + """Flip volume along a specified axis. + If unknown, this axis can be inferred from an affine matrix with a specified anatomical direction. + :param volume: a numpy array + :param axis: (optional) axis along which to flip the volume. Can either be an int or a tuple. + :param direction: (optional) if axis is None, the volume can be flipped along an anatomical direction: + 'rl' (right/left), 'ap' anterior/posterior), 'si' (superior/inferior). + :param aff: (optional) please provide an affine matrix if direction is not None + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: flipped volume + """ + + new_volume = volume.copy() if return_copy else volume + assert (axis is not None) | ((aff is not None) & (direction is not None)), \ + 'please provide either axis, or an affine matrix with a direction' + + # get flipping axis from aff if axis not provided + if (axis is None) & (aff is not None): + volume_axes = get_ras_axes(aff) + if direction == 'rl': + axis = volume_axes[0] + elif direction == 'ap': + axis = volume_axes[1] + elif direction == 'si': + axis = volume_axes[2] + else: + raise ValueError("direction should be 'rl', 'ap', or 'si', had %s" % direction) + + # flip volume + return np.flip(new_volume, axis=axis) + + +def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True): + """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS + :param volume: a numpy array + :param aff: affine matrix of the volume + :param new_vox_size: new voxel size (3 - element numpy vector) in mm + :param interpolation: (optional) type of interpolation. Can be 'linear' or 'nearest'. Default is 'linear'. + :param blur: (optional) whether to blur before resampling to avoid aliasing effects. + Only used if the input volume is downsampled. Default is True. + :return: new volume and affine matrix + """ + + pixdim = np.sqrt(np.sum(aff * aff, axis=0))[:-1] + new_vox_size = np.array(new_vox_size) + factor = pixdim / new_vox_size + sigmas = 0.25 / factor + sigmas[factor > 1] = 0 # don't blur if upsampling + + volume_filt = gaussian_filter(volume, sigmas) if blur else volume + + # volume2 = zoom(volume_filt, factor, order=1, mode='reflect', prefilter=False) + x = np.arange(0, volume_filt.shape[0]) + y = np.arange(0, volume_filt.shape[1]) + z = np.arange(0, volume_filt.shape[2]) + + my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) + + start = - (factor - 1) / (2 * factor) + step = 1.0 / factor + stop = start + step * np.ceil(volume_filt.shape * factor) + + xi = np.arange(start=start[0], stop=stop[0], step=step[0]) + yi = np.arange(start=start[1], stop=stop[1], step=step[1]) + zi = np.arange(start=start[2], stop=stop[2], step=step[2]) + xi[xi < 0] = 0 + yi[yi < 0] = 0 + zi[zi < 0] = 0 + xi[xi > (volume_filt.shape[0] - 1)] = volume_filt.shape[0] - 1 + yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1 + zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1 + + xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=True) + volume2 = my_interpolating_function((xig, yig, zig)) + + aff2 = aff.copy() + for c in range(3): + aff2[:-1, c] = aff2[:-1, c] / factor[c] + aff2[:-1, -1] = aff2[:-1, -1] - np.matmul(aff2[:-1, :-1], 0.5 * (factor - 1)) + + return volume2, aff2 + + +def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='linear'): + """This function reslices a floating image to the space of a reference image + :param vol_ref: a numpy array with the reference volume + :param aff_ref: affine matrix of the reference volume + :param vol_flo: a numpy array with the floating volume + :param aff_flo: affine matrix of the floating volume + :param interpolation: (optional) type of interpolation. Can be 'linear' or 'nearest'. Default is 'linear'. + :return: resliced volume + """ + + T = np.matmul(np.linalg.inv(aff_flo), aff_ref) + + xf = np.arange(0, vol_flo.shape[0]) + yf = np.arange(0, vol_flo.shape[1]) + zf = np.arange(0, vol_flo.shape[2]) + + my_interpolating_function = RegularGridInterpolator((xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, + method=interpolation) + + xr = np.arange(0, vol_ref.shape[0]) + yr = np.arange(0, vol_ref.shape[1]) + zr = np.arange(0, vol_ref.shape[2]) + + xrg, yrg, zrg = np.meshgrid(xr, yr, zr, indexing='ij', sparse=False) + n = xrg.size + xrg = xrg.reshape([n]) + yrg = yrg.reshape([n]) + zrg = zrg.reshape([n]) + bottom = np.ones_like(xrg) + coords = np.stack([xrg, yrg, zrg, bottom]) + coords_new = np.matmul(T, coords)[:-1, :] + result = my_interpolating_function((coords_new[0, :], coords_new[1, :], coords_new[2, :])) + + return result.reshape(vol_ref.shape) + + +def get_ras_axes(aff, n_dims=3): + """This function finds the RAS axes corresponding to each dimension of a volume, based on its affine matrix. + :param aff: affine matrix Can be a 2d numpy array of size n_dims*n_dims, n_dims+1*n_dims+1, or n_dims*n_dims+1. + :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix. + :return: two numpy 1d arrays of length n_dims, one with the axes corresponding to RAS orientations, + and one with their corresponding direction. + """ + aff_inverted = np.linalg.inv(aff) + img_ras_axes = np.argmax(np.absolute(aff_inverted[0:n_dims, 0:n_dims]), axis=0) + for i in range(n_dims): + if i not in img_ras_axes: + unique, counts = np.unique(img_ras_axes, return_counts=True) + incorrect_value = unique[np.argmax(counts)] + img_ras_axes[np.where(img_ras_axes == incorrect_value)[0][-1]] = i + + return img_ras_axes + + +def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True): + """This function aligns a volume to a reference orientation (axis and direction) specified by an affine matrix. + :param volume: a numpy array + :param aff: affine matrix of the floating volume + :param aff_ref: (optional) affine matrix of the target orientation. Default is identity matrix. + :param return_aff: (optional) whether to return the affine matrix of the aligned volume + :param n_dims: (optional) number of dimensions (excluding channels) of the volume. If not provided, n_dims will be + inferred from the input volume. + :param return_copy: (optional) whether to return the original volume or a copy. Default is copy. + :return: aligned volume, with corresponding affine matrix if return_aff is True. + """ + + # work on copy + new_volume = volume.copy() if return_copy else volume + aff_flo = aff.copy() + + # default value for aff_ref + if aff_ref is None: + aff_ref = np.eye(4) + + # extract ras axes + if n_dims is None: + n_dims, _ = utils.get_dims(new_volume.shape) + ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims) + ras_axes_flo = get_ras_axes(aff_flo, n_dims=n_dims) + + # align axes + aff_flo[:, ras_axes_ref] = aff_flo[:, ras_axes_flo] + for i in range(n_dims): + if ras_axes_flo[i] != ras_axes_ref[i]: + new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i]) + swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i]) + ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx] + + # align directions + dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0) + for i in range(n_dims): + if dot_products[i] < 0: + new_volume = np.flip(new_volume, axis=i) + aff_flo[:, i] = - aff_flo[:, i] + aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1) + + if return_aff: + return new_volume, aff_flo + else: + return new_volume + + +def blur_volume(volume, sigma, mask=None): + """Blur volume with gaussian masks of given sigma. + :param volume: 2d or 3d numpy array + :param sigma: standard deviation of the gaussian kernels. Can be a number, a sequence or a 1d numpy array + :param mask: (optional) numpy array of the same shape as volume to correct for edge blurring effects. + Mask can be a boolean or numerical array. In the latter, the mask is computed by keeping all values above zero. + :return: blurred volume + """ + + # initialisation + new_volume = volume.copy() + n_dims, _ = utils.get_dims(new_volume.shape) + sigma = utils.reformat_to_list(sigma, length=n_dims, dtype='float') + + # blur image + new_volume = gaussian_filter(new_volume, sigma=sigma, mode='nearest') # nearest refers to edge padding + + # correct edge effect if mask is not None + if mask is not None: + assert new_volume.shape == mask.shape, 'volume and mask should have the same dimensions: ' \ + 'got {0} and {1}'.format(new_volume.shape, mask.shape) + mask = (mask > 0) * 1.0 + blurred_mask = gaussian_filter(mask, sigma=sigma, mode='nearest') + new_volume = new_volume / (blurred_mask + 1e-6) + new_volume[mask == 0] = 0 + + return new_volume + + +# --------------------------------------------------- edit label map --------------------------------------------------- + +def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, use_nearest_label=False, + remove_zero=False, smooth=False): + """This function corrects specified label values in a label map by either a list of given values, or by the nearest + label. + :param labels: a 2d or 3d label map + :param list_incorrect_labels: list of all label values to correct (eg [1, 2, 3]). Can also be a path to such a list. + :param list_correct_labels: (optional) list of correct label values to replace the incorrect ones. + Correct values must have the same order as their corresponding value in list_incorrect_labels. + When several correct values are possible for the same incorrect value, the nearest correct value will be selected at + each voxel to correct. In that case, the different correct values must be specified inside a list within + list_correct_labels (e.g. [10, 20, 30, [40, 50]). + :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels. + :param remove_zero: (optional) if use_nearest_label is True, set to True not to consider zero among the potential + candidates for the nearest neighbour. -1 will be returned when no solution are possible. + :param smooth: (optional) whether to smooth the corrected label map + :return: corrected label map + """ + + assert (list_correct_labels is not None) | use_nearest_label, \ + 'please provide a list of correct labels, or set use_nearest_label to True.' + assert (list_correct_labels is None) | (not use_nearest_label), \ + 'cannot provide a list of correct values and set use_nearest_label to True' + + # initialisation + new_labels = labels.copy() + list_incorrect_labels = utils.reformat_to_list(utils.load_array_if_path(list_incorrect_labels)) + volume_labels = np.unique(labels) + n_dims, _ = utils.get_dims(labels.shape) + + # use list of correct values + if list_correct_labels is not None: + list_correct_labels = utils.reformat_to_list(utils.load_array_if_path(list_correct_labels)) + + # loop over label values + for incorrect_label, correct_label in zip(list_incorrect_labels, list_correct_labels): + if incorrect_label in volume_labels: + + # only one possible value to replace with + if isinstance(correct_label, (int, float, np.int64, np.int32, np.int16, np.int8)): + incorrect_voxels = np.where(labels == incorrect_label) + new_labels[incorrect_voxels] = correct_label + + # several possibilities + elif isinstance(correct_label, (tuple, list)): + + # make sure at least one correct label is present + if not any([lab in volume_labels for lab in correct_label]): + print('no correct values found in volume, please adjust: ' + 'incorrect: {}, correct: {}'.format(incorrect_label, correct_label)) + + # crop around incorrect label until we find incorrect labels + correct_label_not_found = True + margin_mult = 1 + tmp_labels = None + crop = None + while correct_label_not_found: + tmp_labels, crop = crop_volume_around_region(labels, + masking_labels=incorrect_label, + margin=10 * margin_mult) + correct_label_not_found = not any([lab in np.unique(tmp_labels) for lab in correct_label]) + margin_mult += 1 + + # calculate distance maps for all new label candidates + incorrect_voxels = np.where(tmp_labels == incorrect_label) + distance_map_list = [distance_transform_edt(tmp_labels != lab) for lab in correct_label] + distances_correct = np.stack([dist[incorrect_voxels] for dist in distance_map_list]) + + # select nearest values and use them to correct label map + idx_correct_lab = np.argmin(distances_correct, axis=0) + incorrect_voxels = tuple([incorrect_voxels[i] + crop[i] for i in range(n_dims)]) + new_labels[incorrect_voxels] = np.array(correct_label)[idx_correct_lab] + + # use nearest label + else: + + # loop over label values + for incorrect_label in list_incorrect_labels: + if incorrect_label in volume_labels: + + # loop around regions + components, n_components = scipy_label(labels == incorrect_label) + loop_info = utils.LoopInfo(n_components + 1, 100, 'correcting') + for i in range(1, n_components + 1): + loop_info.update(i) + + # crop each region + _, crop = crop_volume_around_region(components, masking_labels=i, margin=1) + tmp_labels = crop_volume_with_idx(labels, crop) + tmp_new_labels = crop_volume_with_idx(new_labels, crop) + + # list all possible correct labels + correct_labels = np.unique(tmp_labels) + for il in list_incorrect_labels: + correct_labels = np.delete(correct_labels, np.where(correct_labels == il)) + if remove_zero: + correct_labels = np.delete(correct_labels, np.where(correct_labels == 0)) + + # replace incorrect voxels by new value + incorrect_voxels = np.where(tmp_labels == incorrect_label) + if len(correct_labels) == 0: + tmp_new_labels[incorrect_voxels] = -1 + else: + if len(correct_labels) == 1: + idx_correct_lab = np.zeros(len(incorrect_voxels[0]), dtype='int32') + else: + distance_map_list = [distance_transform_edt(tmp_labels != lab) for lab in correct_labels] + distances_correct = np.stack([dist[incorrect_voxels] for dist in distance_map_list]) + idx_correct_lab = np.argmin(distances_correct, axis=0) + tmp_new_labels[incorrect_voxels] = np.array(correct_labels)[idx_correct_lab] + + # paste back + if n_dims == 2: + new_labels[crop[0]:crop[2], crop[1]:crop[3], ...] = tmp_new_labels + else: + new_labels[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] = tmp_new_labels + + # smoothing + if smooth: + kernel = np.ones(tuple([3] * n_dims)) + new_labels = smooth_label_map(new_labels, kernel) + + return new_labels + + +def mask_label_map(labels, masking_values, masking_value=0, return_mask=False): + """ + This function masks a label map around a list of specified values. + :param labels: input label map + :param masking_values: list of values to mask around + :param masking_value: (optional) value to mask the label map with + :param return_mask: (optional) whether to return the applied mask + :return: the masked label map, and the applied mask if return_mask is True. + """ + + # build mask and mask labels + mask = np.zeros(labels.shape, dtype=bool) + masked_labels = labels.copy() + for value in utils.reformat_to_list(masking_values): + mask = mask | (labels == value) + masked_labels[np.logical_not(mask)] = masking_value + + if return_mask: + mask = mask * 1 + return masked_labels, mask + else: + return masked_labels + + +def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): + """This function smooth an input label map by replacing each voxel by the value of its most numerous neighbour. + :param labels: input label map + :param kernel: kernel when counting neighbours. Must contain only zeros or ones. + :param labels_list: list of label values to smooth. Defaults is None, where all labels are smoothed. + :param print_progress: (optional) If not 0, interval at which to print the number of processed labels. + :return: smoothed label map + """ + # get info + labels_shape = labels.shape + unique_labels = np.unique(labels).astype('int32') + if labels_list is None: + labels_list = unique_labels + new_labels = mask_new_labels = None + else: + labels_to_keep = [lab for lab in unique_labels if lab not in labels_list] + new_labels, mask_new_labels = mask_label_map(labels, labels_to_keep, return_mask=True) + + # loop through label values + count = np.zeros(labels_shape) + labels_smoothed = np.zeros(labels_shape, dtype='int') + loop_info = utils.LoopInfo(len(labels_list), print_progress, 'smoothing') + for la, label in enumerate(labels_list): + if print_progress: + loop_info.update(la) + + # count neighbours with same value + mask = (labels == label) * 1 + n_neighbours = convolve(mask, kernel) + + # update label map and maximum neighbour counts + idx = n_neighbours > count + count[idx] = n_neighbours[idx] + labels_smoothed[idx] = label + labels_smoothed = labels_smoothed.astype('int32') + + if new_labels is None: + new_labels = labels_smoothed + else: + new_labels = np.where(mask_new_labels, new_labels, labels_smoothed) + + return new_labels + + +def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, model=None, return_model=False): + """Erode a given set of label values within a label map. + :param labels: a 2d or 3d label map + :param labels_to_erode: list of label values to erode + :param erosion_factors: (optional) list of erosion factors to use for each label. If values are integers, normal + erosion applies. If float, we first 1) blur a mask of the corresponding label value, and 2) use the erosion factor + as a threshold in the blurred mask. + If erosion_factors is a single value, the same factor will be applied to all labels. + :param gpu: (optional) whether to use a fast gpu model for blurring (if erosion factors are floats) + :param model: (optional) gpu model for blurring masks (if erosion factors are floats) + :param return_model: (optional) whether to return the gpu blurring model + :return: eroded label map, and gpu blurring model is return_model is True. + """ + # reformat labels_to_erode and erode + new_labels = labels.copy() + labels_to_erode = utils.reformat_to_list(labels_to_erode) + erosion_factors = utils.reformat_to_list(erosion_factors, length=len(labels_to_erode)) + labels_shape = list(new_labels.shape) + n_dims, _ = utils.get_dims(labels_shape) + + # loop over labels to erode + for label_to_erode, erosion_factor in zip(labels_to_erode, erosion_factors): + + assert erosion_factor > 0, 'all erosion factors should be strictly positive, had {}'.format(erosion_factor) + + # get mask of current label value + mask = (new_labels == label_to_erode) + + # erode as usual if erosion factor is int + if int(erosion_factor) == erosion_factor: + erode_struct = utils.build_binary_structure(int(erosion_factor), n_dims) + eroded_mask = binary_erosion(mask, erode_struct) + + # blur mask and use erosion factor as a threshold if float + else: + if gpu: + if model is None: + mask_in = KL.Input(shape=labels_shape + [1], dtype='float32') + blurred_mask = GaussianBlur([1] * 3)(mask_in) + model = Model(inputs=mask_in, outputs=blurred_mask) + eroded_mask = model.predict(utils.add_axis(np.float32(mask), axis=[0, -1])) + else: + eroded_mask = blur_volume(np.array(mask, dtype='float32'), 1) + eroded_mask = np.squeeze(eroded_mask) > erosion_factor + + # crop label map and mask around values to change + mask = mask & np.logical_not(eroded_mask) + cropped_lab_mask, cropping = crop_volume_around_region(mask, margin=3) + cropped_labels = crop_volume_with_idx(new_labels, cropping) + + # calculate distance maps for all labels in cropped_labels + labels_list = np.unique(cropped_labels) + labels_list = labels_list[labels_list != label_to_erode] + list_dist_maps = [distance_transform_edt(np.logical_not(cropped_labels == la)) for la in labels_list] + candidate_distances = np.stack([dist[cropped_lab_mask] for dist in list_dist_maps]) + + # select nearest value and put cropped labels back to full label map + idx_correct_lab = np.argmin(candidate_distances, axis=0) + cropped_labels[cropped_lab_mask] = np.array(labels_list)[idx_correct_lab] + if n_dims == 2: + new_labels[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] = cropped_labels + elif n_dims == 3: + new_labels[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] = cropped_labels + + if return_model: + return new_labels, model + else: + return new_labels + + +def get_largest_connected_component(mask, structure=None): + """Function to get the largest connected component for a given input. + :param mask: a 2d or 3d label map of boolean type. + :param structure: numpy array defining the connectivity. + """ + components, n_components = scipy_label(mask, structure) + return components == np.argmax(np.bincount(components.flat)[1:]) + 1 if n_components > 0 else mask.copy() + + +def compute_hard_volumes(labels, voxel_volume=1., label_list=None, skip_background=True): + """Compute hard volumes in a label map. + :param labels: a label map + :param voxel_volume: (optional) volume of voxel. Default is 1 (i.e. returned volumes are voxel counts). + :param label_list: (optional) list of labels to compute volumes for. Can be an int, a sequence, or a numpy array. + If None, the volumes of all label values are computed. + :param skip_background: (optional) whether to skip computing the volume of the background. + If label_list is None, this assumes background value is 0. + If label_list is not None, this assumes the background is the first value in label list. + :return: numpy 1d vector with the volumes of each structure + """ + + # initialisation + subject_label_list = utils.reformat_to_list(np.unique(labels), dtype='int') + if label_list is None: + label_list = subject_label_list + else: + label_list = utils.reformat_to_list(label_list) + if skip_background: + label_list = label_list[1:] + volumes = np.zeros(len(label_list)) + + # loop over label values + for idx, label in enumerate(label_list): + if label in subject_label_list: + mask = (labels == label) * 1 + volumes[idx] = np.sum(mask) + else: + volumes[idx] = 0 + + return volumes * voxel_volume + + +def compute_distance_map(labels, masking_labels=None, crop_margin=None): + """Compute distance map for a given list of label values in a label map. + :param labels: a label map + :param masking_labels: (optional) list of label values to mask the label map with. The distances will be computed + for these labels only. Default is None, where all positive values are considered. + :param crop_margin: (optional) margin with which to crop the input label maps around the labels for which we + want to compute the distance maps. + :return: a distance map with positive values inside the considered regions, and negative values outside.""" + + n_dims, _ = utils.get_dims(labels.shape) + + # crop label map if necessary + if crop_margin is not None: + tmp_labels, crop_idx = crop_volume_around_region(labels, margin=crop_margin) + else: + tmp_labels = labels + crop_idx = None + + # mask label map around specify values + if masking_labels is not None: + masking_labels = utils.reformat_to_list(masking_labels) + mask = np.zeros(tmp_labels.shape, dtype='bool') + for masking_label in masking_labels: + mask = mask | tmp_labels == masking_label + else: + mask = tmp_labels > 0 + not_mask = np.logical_not(mask) + + # compute distances + dist_in = distance_transform_edt(mask) + dist_in = np.where(mask, dist_in - 0.5, dist_in) + dist_out = - distance_transform_edt(not_mask) + dist_out = np.where(not_mask, dist_out + 0.5, dist_out) + tmp_dist = dist_in + dist_out + + # put back in original matrix if we cropped + if crop_idx is not None: + dist = np.min(tmp_dist) * np.ones(labels.shape, dtype='float32') + if n_dims == 3: + dist[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...] = tmp_dist + elif n_dims == 2: + dist[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...] = tmp_dist + else: + dist = tmp_dist + + return dist + + +# ------------------------------------------------- edit volumes in dir ------------------------------------------------ + +def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, + masking_value=0, write_mask=False, mask_result_dir=None, recompute=True): + """Mask all volumes in a folder, either with masks in a specified folder, or by keeping only the intensity values + above a specified threshold. + :param image_dir: path of directory with images to mask + :param result_dir: path of directory where masked images will be writen + :param mask_dir: (optional) path of directory containing masks. Masks are matched to images by sorting order. + Mask volumes don't have to be boolean or 0/1 arrays as all strictly positive values are used to build the masks. + Masks should have the same size as images. If images are multi-channel, masks can either be uni- or multi-channel. + In the first case, the same mask is applied to all channels. + :param threshold: (optional) If mask is None, masking is performed by keeping thresholding the input. + :param dilate: (optional) number of voxels by which to dilate the provided or computed masks. + :param erode: (optional) number of voxels by which to erode the provided or computed masks. + :param fill_holes: (optional) whether to fill the holes in the provided or computed masks. + :param masking_value: (optional) masking value + :param write_mask: (optional) whether to write the applied masks + :param mask_result_dir: (optional) path of resulting masks, if write_mask is True + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + if mask_result_dir is not None: + utils.mkdir(mask_result_dir) + + # get path masks if necessary + path_images = utils.list_images_in_folder(image_dir) + if mask_dir is not None: + path_masks = utils.list_images_in_folder(mask_dir) + else: + path_masks = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, 'masking', True) + for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): + loop_info.update(idx) + + # mask images + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + if path_mask is not None: + mask = utils.load_volume(path_mask) + else: + mask = None + im = mask_volume(im, mask, threshold, dilate, erode, fill_holes, masking_value, write_mask) + + # write mask if necessary + if write_mask: + assert mask_result_dir is not None, 'if write_mask is True, mask_result_dir has to be specified as well' + mask_result_path = os.path.join(mask_result_dir, os.path.basename(path_image)) + utils.save_volume(im[1], aff, h, mask_result_path) + utils.save_volume(im[0], aff, h, path_result) + else: + utils.save_volume(im, aff, h, path_result) + + +def rescale_images_in_dir(image_dir, result_dir, + new_min=0, new_max=255, + min_percentile=2, max_percentile=98, use_positive_only=True, + recompute=True): + """This function linearly rescales all volumes in image_dir between new_min and new_max. + :param image_dir: path of directory with images to rescale + :param result_dir: path of directory where rescaled images will be writen + :param new_min: (optional) minimum value for the rescaled images. + :param new_max: (optional) maximum value for the rescaled images. + :param min_percentile: (optional) percentile for estimating robust minimum of volume (float in [0,...100]), + where 0 = np.min + :param max_percentile: (optional) percentile for estimating robust maximum of volume (float in [0,...100]), + where 100 = np.max + :param use_positive_only: (optional) whether to use only positive values when estimating the min and max percentile + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, 'rescaling', True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im = rescale_volume(im, new_min, new_max, min_percentile, max_percentile, use_positive_only) + utils.save_volume(im, aff, h, path_result) + + +def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_shape=None, recompute=True): + """Crop all volumes in a folder by a given margin, or to a given shape. + :param image_dir: path of directory with images to rescale + :param result_dir: path of directory where cropped images will be writen + :param cropping_margin: (optional) margin by which to crop the volume. + Can be an int, a sequence or a 1d numpy array. Should be given if cropping_shape is None. + :param cropping_shape: (optional) shape to which the volume will be cropped. + Can be an int, a sequence or a 1d numpy array. Should be given if cropping_margin is None. + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # loop over images and masks + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # crop image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + volume, aff, h = utils.load_volume(path_image, im_only=False) + volume, aff = crop_volume(volume, cropping_margin, cropping_shape, aff) + utils.save_volume(volume, aff, h, path_result) + + +def crop_images_around_region_in_dir(image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + masking_labels=None, + crop_margin=5, + recompute=True): + """Crop all volumes in a folder around a region, which is defined for each volume by a mask obtained by either + 1) directly providing it as input + 2) thresholding the input volume + 3) keeping a set of label values if the volume is a label map. + :param image_dir: path of directory with images to crop + :param result_dir: path of directory where cropped images will be writen + :param mask_dir: (optional) path of directory of input masks + :param threshold: (optional) lower bound to determine values to crop around + :param masking_labels: (optional) if the volume is a label map, it can be cropped around a given set of labels by + specifying them in masking_labels, which can either be a single int, a list or a 1d numpy array. + :param crop_margin: (optional) cropping margin + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list volumes and masks + path_images = utils.list_images_in_folder(image_dir) + if mask_dir is not None: + path_masks = utils.list_images_in_folder(mask_dir) + else: + path_masks = [None] * len(path_images) + + # loop over images and masks + loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): + loop_info.update(idx) + + # crop image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + volume, aff, h = utils.load_volume(path_image, im_only=True) + if path_mask is not None: + mask = utils.load_volume(path_mask) + else: + mask = None + volume, cropping, aff = crop_volume_around_region(volume, mask, threshold, masking_labels, crop_margin, aff) + utils.save_volume(volume, aff, h, path_result) + + +def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, recompute=True): + """Pads all the volumes in a folder to the same shape (either provided or computed). + :param image_dir: path of directory with images to pad + :param result_dir: path of directory where padded images will be writen + :param max_shape: (optional) shape to pad the volumes to. Can be an int, a sequence or a 1d numpy array. + If None, volumes will be padded to the shape of the biggest volume in image_dir. + :param padding_value: (optional) value to pad the volumes with. + :param recompute: (optional) whether to recompute result files even if they already exist + :return: shape of the padded volumes. + """ + + # create result dir + utils.mkdir(result_dir) + + # list labels + path_images = utils.list_images_in_folder(image_dir) + + # get maximum shape + if max_shape is None: + max_shape, aff, _, _, h, _ = utils.get_volume_info(path_images[0]) + for path_image in path_images[1:]: + image_shape, aff, _, _, h, _ = utils.get_volume_info(path_image) + max_shape = tuple(np.maximum(np.asarray(max_shape), np.asarray(image_shape))) + max_shape = np.array(max_shape) + + # loop over label maps + loop_info = utils.LoopInfo(len(path_images), 10, 'padding', True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # pad map + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im, aff = pad_volume(im, max_shape, padding_value, aff) + utils.save_volume(im, aff, h, path_result) + + return max_shape + + +def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompute=True): + """Flip all images in a directory along a specified axis. + If unknown, this axis can be replaced by an anatomical direction. + :param image_dir: path of directory with images to flip + :param result_dir: path of directory where flipped images will be writen + :param axis: (optional) axis along which to flip the volume + :param direction: (optional) if axis is None, the volume can be flipped along an anatomical direction: + 'rl' (right/left), 'ap' (anterior/posterior), 'si' (superior/inferior). + :param recompute: (optional) whether to recompute result files even if they already exists + """ + # create result dir + utils.mkdir(result_dir) + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, 'flipping', True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # flip image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im = flip_volume(im, axis=axis, direction=direction, aff=aff) + utils.save_volume(im, aff, h, path_result) + + +def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, recompute=True): + """This function aligns all images in image_dir to a reference orientation (axes and directions). + This reference orientation can be directly provided as an affine matrix, or can be specified by a reference volume. + If neither are provided, the reference orientation is assumed to be an identity matrix. + :param image_dir: path of directory with images to align + :param result_dir: path of directory where flipped images will be writen + :param aff_ref: (optional) reference affine matrix. Can be a numpy array, or the path to such array. + :param path_ref: (optional) path of a volume to which all images will be aligned. Can also be the path to a folder + with as many images as in image_dir, in which case each image in image_dir is aligned to its counterpart in path_ref + (they are matched by sorting order). + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + path_images = utils.list_images_in_folder(image_dir) + + # read reference affine matrix + if path_ref is not None: + assert aff_ref is None, 'cannot provide aff_ref and path_ref together.' + basename = os.path.basename(path_ref) + if ('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename): + _, aff_ref, _ = utils.load_volume(path_ref, im_only=False) + path_refs = [None] * len(path_images) + else: + path_refs = utils.list_images_in_folder(path_ref) + elif aff_ref is not None: + aff_ref = utils.load_array_if_path(aff_ref) + path_refs = [None] * len(path_images) + else: + aff_ref = np.eye(4) + path_refs = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, 'aligning', True) + for idx, (path_image, path_ref) in enumerate(zip(path_images, path_refs)): + loop_info.update(idx) + + # align image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + if path_ref is not None: + _, aff_ref, _ = utils.load_volume(path_ref, im_only=False) + im, aff = align_volume_to_ref(im, aff, aff_ref=aff_ref, return_aff=True) + utils.save_volume(im, aff, h, path_result) + + +def correct_nans_images_in_dir(image_dir, result_dir, recompute=True): + """Correct NaNs in all images in a directory. + :param image_dir: path of directory with images to correct + :param result_dir: path of directory where corrected images will be writen + :param recompute: (optional) whether to recompute result files even if they already exists + """ + # create result dir + utils.mkdir(result_dir) + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, 'correcting', True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # flip image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + im[np.isnan(im)] = 0 + utils.save_volume(im, aff, h, path_result) + + +def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, recompute=True): + """This function blurs all the images in image_dir with kernels of the specified std deviations. + :param image_dir: path of directory with images to blur + :param result_dir: path of directory where blurred images will be writen + :param sigma: standard deviation of the blurring gaussian kernels. + Can be a number (isotropic blurring), or a sequence with the same length as the number of dimensions of images. + :param mask_dir: (optional) path of directory with masks of the region to blur. + Images and masks are matched by sorting order. + :param gpu: (optional) whether to use a fast gpu model for blurring + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list images and masks + path_images = utils.list_images_in_folder(image_dir) + if mask_dir is not None: + path_masks = utils.list_images_in_folder(mask_dir) + else: + path_masks = [None] * len(path_images) + + # loop over images + previous_model_input_shape = None + model = None + loop_info = utils.LoopInfo(len(path_images), 10, 'blurring', True) + for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): + loop_info.update(idx) + + # load image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + im, im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_image, return_volume=True) + if path_mask is not None: + mask = utils.load_volume(path_mask) + assert mask.shape == im.shape, 'mask and image should have the same shape' + else: + mask = None + + # blur image + if gpu: + if (im_shape != previous_model_input_shape) | (model is None): + previous_model_input_shape = im_shape + inputs = [KL.Input(shape=im_shape + [1])] + sigma = utils.reformat_to_list(sigma, length=n_dims) + if mask is None: + image = GaussianBlur(sigma=sigma)(inputs[0]) + else: + inputs.append(KL.Input(shape=im_shape + [1], dtype='float32')) + image = GaussianBlur(sigma=sigma, use_mask=True)(inputs) + model = Model(inputs=inputs, outputs=image) + if mask is None: + im = np.squeeze(model.predict(utils.add_axis(im, axis=[0, -1]))) + else: + im = np.squeeze(model.predict([utils.add_axis(im, [0, -1]), utils.add_axis(mask, [0, -1])])) + else: + im = blur_volume(im, sigma, mask=mask) + utils.save_volume(im, aff, h, path_result) + + +def create_mutlimodal_images(list_channel_dir, result_dir, recompute=True): + """This function forms multimodal images by stacking channels located in different folders. + :param list_channel_dir: list of all directories, each containing the same channel for all images. + Channels are matched between folders by sorting order. + :param result_dir: path of directory where multimodal images will be writen + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + assert isinstance(list_channel_dir, (list, tuple)), 'list_channel_dir should be a list or a tuple' + + # gather path of all images for all channels + list_channel_paths = [utils.list_images_in_folder(d) for d in list_channel_dir] + n_images = len(list_channel_paths[0]) + n_channels = len(list_channel_dir) + for channel_paths in list_channel_paths: + if len(channel_paths) != n_images: + raise ValueError('all directories should have the same number of files') + + # loop over images + loop_info = utils.LoopInfo(n_images, 10, 'processing', True) + for idx in range(n_images): + loop_info.update(idx) + + # stack all channels and save multichannel image + path_result = os.path.join(result_dir, os.path.basename(list_channel_paths[0][idx])) + if (not os.path.isfile(path_result)) | recompute: + list_channels = list() + tmp_aff = None + tmp_h = None + for channel_idx in range(n_channels): + tmp_channel, tmp_aff, tmp_h = utils.load_volume(list_channel_paths[channel_idx][idx], im_only=False) + list_channels.append(tmp_channel) + im = np.stack(list_channels, axis=-1) + utils.save_volume(im, tmp_aff, tmp_h, path_result) + + +def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir=None, recompute=True): + """Converts all images in image_dir to nifty format. + :param image_dir: path of directory with images to convert + :param result_dir: path of directory where converted images will be writen + :param aff: (optional) affine matrix in homogeneous coordinates with which to write the images. + Can also be 'FS' to write images with FreeSurfer typical affine matrix. + :param ref_aff_dir: (optional) alternatively to providing a fixed aff, different affine matrices can be used for + each image in image_dir by matching them to corresponding volumes contained in ref_aff_dir. + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list images + path_images = utils.list_images_in_folder(image_dir) + if ref_aff_dir is not None: + path_ref_images = utils.list_images_in_folder(ref_aff_dir) + else: + path_ref_images = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, 'converting', True) + for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)): + loop_info.update(idx) + + # convert images to nifty format + path_result = os.path.join(result_dir, os.path.basename(utils.strip_extension(path_image))) + '.nii.gz' + if (not os.path.isfile(path_result)) | recompute: + if utils.get_image_extension(path_image) == 'nii.gz': + shutil.copy2(path_image, path_result) + else: + im, tmp_aff, h = utils.load_volume(path_image, im_only=False) + if aff is not None: + tmp_aff = aff + elif path_ref is not None: + _, tmp_aff, h = utils.load_volume(path_ref, im_only=False) + utils.save_volume(im, tmp_aff, h, path_result) + + +def mri_convert_images_in_dir(image_dir, + result_dir, + interpolation=None, + reference_dir=None, + same_reference=False, + voxsize=None, + path_freesurfer='/usr/local/freesurfer', + mri_convert_path='/usr/local/freesurfer/bin/mri_convert', + recompute=True): + """This function launches mri_convert on all images contained in image_dir, and writes the results in result_dir. + The interpolation type can be specified (i.e. 'nearest'), as well as a folder containing references for resampling. + reference_dir can be the path of a single *image* if same_reference=True. + :param image_dir: path of directory with images to convert + :param result_dir: path of directory where converted images will be writen + :param interpolation: (optional) interpolation type, can be 'inter' (default), 'cubic', 'nearest', 'trilinear' + :param reference_dir: (optional) path of directory with reference images. References are matched to images by + sorting order. If same_reference is false, references and images are matched by sorting order. + This can also be the path to a single image that will be used as reference for all images im image_dir (set + same_reference to true in that case). + :param same_reference: (optional) whether to use a single image as reference for all images to interpolate. + :param voxsize: (optional) resolution at which to resample converted image. Must be a list of length n_dims. + :param path_freesurfer: (optional) path FreeSurfer home + :param mri_convert_path: (optional) path mri_convert binary file + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # set up FreeSurfer + os.environ['FREESURFER_HOME'] = path_freesurfer + os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) + mri_convert = mri_convert_path + ' ' + + # list images + path_images = utils.list_images_in_folder(image_dir) + if reference_dir is not None: + if same_reference: + path_references = [reference_dir] * len(path_images) + else: + path_references = utils.list_images_in_folder(reference_dir) + assert len(path_references) == len(path_images), 'different number of files in image_dir and reference_dir' + else: + path_references = [None] * len(path_images) + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, 'converting', True) + for idx, (path_image, path_reference) in enumerate(zip(path_images, path_references)): + loop_info.update(idx) + + # convert image + path_result = os.path.join(result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_result)) | recompute: + cmd = mri_convert + path_image + ' ' + path_result + ' -odt float' + if interpolation is not None: + cmd += ' -rt ' + interpolation + if reference_dir is not None: + cmd += ' -rl ' + path_reference + if voxsize is not None: + voxsize = utils.reformat_to_list(voxsize, dtype='float') + cmd += ' --voxsize ' + ' '.join([str(np.around(v, 3)) for v in voxsize]) + os.system(cmd) + + +def samseg_images_in_dir(image_dir, + result_dir, + atlas_dir=None, + threads=4, + path_freesurfer='/usr/local/freesurfer', + keep_segm_only=True, + recompute=True): + """This function launches samseg for all images contained in image_dir and writes the results in result_dir. + If keep_segm_only=True, the result segmentation is copied in result_dir and SAMSEG's intermediate result dir is + deleted. + :param image_dir: path of directory with input images + :param result_dir: path of directory where processed images folders (if keep_segm_only is False), + or samseg segmentation (if keep_segm_only is True) will be writen + :param atlas_dir: (optional) path of samseg atlas directory. If None, use samseg default atlas. + :param threads: (optional) number of threads to use + :param path_freesurfer: (optional) path FreeSurfer home + :param keep_segm_only: (optional) whether to keep samseg result folders, or only samseg segmentations. + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # set up FreeSurfer + os.environ['FREESURFER_HOME'] = path_freesurfer + os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) + path_samseg = os.path.join(path_freesurfer, 'bin', 'run_samseg') + + # loop over images + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + for idx, path_image in enumerate(path_images): + loop_info.update(idx) + + # build path_result + path_im_result_dir = os.path.join(result_dir, utils.strip_extension(os.path.basename(path_image))) + path_samseg_result = os.path.join(path_im_result_dir, 'seg.mgz') + if keep_segm_only: + path_result = os.path.join(result_dir, utils.strip_extension(os.path.basename(path_image)) + '_seg.mgz') + else: + path_result = path_samseg_result + + # run samseg + if (not os.path.isfile(path_result)) | recompute: + cmd = utils.mkcmd(path_samseg, '-i', path_image, '-o', path_im_result_dir, '--threads', threads) + if atlas_dir is not None: + cmd = utils.mkcmd(cmd, '-a', atlas_dir) + os.system(cmd) + + # move segmentation to result_dir if necessary + if keep_segm_only: + if os.path.isfile(path_samseg_result): + shutil.move(path_samseg_result, path_result) + if os.path.isdir(path_im_result_dir): + shutil.rmtree(path_im_result_dir) + + +def niftyreg_images_in_dir(image_dir, + reference_dir, + nifty_reg_function='reg_resample', + input_transformation_dir=None, + result_dir=None, + result_transformation_dir=None, + interpolation=None, + same_floating=False, + same_reference=False, + same_transformation=False, + path_nifty_reg='/home/benjamin/Softwares/niftyreg-gpu/build/reg-apps', + recompute=True): + """This function launches one of niftyreg functions (reg_aladin, reg_f3d, reg_resample) on all images contained + in image_dir. + :param image_dir: path of directory with images to register. Can also be a single image, in that case set + same_floating to True. + :param reference_dir: path of directory with reference images. If same_reference is false, references and images are + matched by sorting order. This can also be the path to a single image that will be used as reference for all images + im image_dir (set same_reference to True in that case). + :param nifty_reg_function: (optional) name of the niftyreg function to use. Can be 'reg_aladin', 'reg_f3d', or + 'reg_resample'. Default is 'reg_resample'. + :param input_transformation_dir: (optional) path of a directory containing all the input transformation (for + reg_resample, or reg_f3d). Can also be the path to a single transformation that will be used for all images + in image_dir (set same_transformation to True in that case). + :param result_dir: path of directory where output images will be writen. + :param result_transformation_dir: path of directory where resulting transformations will be writen (for + reg_aladin and reg_f3d). + :param interpolation: (optional) integer describing the order of the interpolation to apply (0 = nearest neighbours) + :param same_floating: (optional) set to true if only one image is used as floating image. + :param same_reference: (optional) whether to use a single image as reference for all input images. + :param same_transformation: (optional) whether to apply the same transformation to all floating images. + :param path_nifty_reg: (optional) path of the folder containing nifty-reg functions + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dirs + if result_dir is not None: + utils.mkdir(result_dir) + if result_transformation_dir is not None: + utils.mkdir(result_transformation_dir) + + nifty_reg = os.path.join(path_nifty_reg, nifty_reg_function) + + # list reference and floating images + path_images = utils.list_images_in_folder(image_dir) + path_references = utils.list_images_in_folder(reference_dir) + if same_reference: + path_references = utils.reformat_to_list(path_references, length=len(path_images)) + if same_floating: + path_images = utils.reformat_to_list(path_images, length=len(path_references)) + assert len(path_references) == len(path_images), 'different number of files in image_dir and reference_dir' + + # list input transformations + if input_transformation_dir is not None: + if same_transformation: + path_input_transfs = utils.reformat_to_list(input_transformation_dir, length=len(path_images)) + else: + path_input_transfs = utils.list_files(input_transformation_dir) + assert len(path_input_transfs) == len(path_images), 'different number of transformations and images' + else: + path_input_transfs = [None] * len(path_images) + + # define flag input trans + if input_transformation_dir is not None: + if nifty_reg_function == 'reg_aladin': + flag_input_trans = '-inaff' + elif nifty_reg_function == 'reg_f3d': + flag_input_trans = '-aff' + elif nifty_reg_function == 'reg_resample': + flag_input_trans = '-trans' + else: + raise Exception('nifty_reg_function can only be "reg_aladin", "reg_f3d", or "reg_resample"') + else: + flag_input_trans = None + + # define flag result transformation + if result_transformation_dir is not None: + if nifty_reg_function == 'reg_aladin': + flag_result_trans = '-aff' + elif nifty_reg_function == 'reg_f3d': + flag_result_trans = '-cpp' + else: + raise Exception('result_transformation_dir can only be used with "reg_aladin" or "reg_f3d"') + else: + flag_result_trans = None + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + for idx, (path_image, path_ref, path_input_trans) in enumerate(zip(path_images, + path_references, + path_input_transfs)): + loop_info.update(idx) + + # define path registered image + name = os.path.basename(path_ref) if same_floating else os.path.basename(path_image) + if result_dir is not None: + path_result = os.path.join(result_dir, name) + result_already_computed = os.path.isfile(path_result) + else: + path_result = None + result_already_computed = True + + # define path resulting transformation + if result_transformation_dir is not None: + if nifty_reg_function == 'reg_aladin': + path_result_trans = os.path.join(result_transformation_dir, utils.strip_extension(name) + '.txt') + result_trans_already_computed = os.path.isfile(path_result_trans) + else: + path_result_trans = os.path.join(result_transformation_dir, name) + result_trans_already_computed = os.path.isfile(path_result_trans) + else: + path_result_trans = None + result_trans_already_computed = True + + if (not result_already_computed) | (not result_trans_already_computed) | recompute: + + # build main command + cmd = utils.mkcmd(nifty_reg, '-ref', path_ref, '-flo', path_image, '-pad 0') + + # add options + if path_result is not None: + cmd = utils.mkcmd(cmd, '-res', path_result) + if flag_input_trans is not None: + cmd = utils.mkcmd(cmd, flag_input_trans, path_input_trans) + if flag_result_trans is not None: + cmd = utils.mkcmd(cmd, flag_result_trans, path_result_trans) + if interpolation is not None: + cmd = utils.mkcmd(cmd, '-inter', interpolation) + + # execute + os.system(cmd) + + +def upsample_anisotropic_images(image_dir, + resample_image_result_dir, + resample_like_dir, + path_freesurfer='/usr/local/freesurfer/', + recompute=True): + """This function takes as input a set of LR images and resample them to HR with respect to reference images. + :param image_dir: path of directory with input images (only uni-modal images supported) + :param resample_image_result_dir: path of directory where resampled images will be writen + :param resample_like_dir: path of directory with reference images. + :param path_freesurfer: (optional) path freesurfer home, as this function uses mri_convert + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(resample_image_result_dir) + + # set up FreeSurfer + os.environ['FREESURFER_HOME'] = path_freesurfer + os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) + mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + + # list images and labels + path_images = utils.list_images_in_folder(image_dir) + path_ref_images = utils.list_images_in_folder(resample_like_dir) + assert len(path_images) == len(path_ref_images), \ + 'the folders containing the images and their references are not the same size' + + # loop over images + loop_info = utils.LoopInfo(len(path_images), 10, 'upsampling', True) + for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)): + loop_info.update(idx) + + # upsample image + _, _, n_dims, _, _, image_res = utils.get_volume_info(path_image, return_volume=False) + path_im_upsampled = os.path.join(resample_image_result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_im_upsampled)) | recompute: + cmd = utils.mkcmd(mri_convert, path_image, path_im_upsampled, '-rl', path_ref, '-odt float') + os.system(cmd) + + path_dist_map = os.path.join(resample_image_result_dir, 'dist_map_' + os.path.basename(path_image)) + if (not os.path.isfile(path_dist_map)) | recompute: + im, aff, h = utils.load_volume(path_image, im_only=False) + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing='ij') + tmp_dir = utils.strip_extension(path_im_upsampled) + '_meshes' + utils.mkdir(tmp_dir) + path_meshes_up = list() + for (i, maps) in enumerate(dist_map): + path_mesh = os.path.join(tmp_dir, '%s_' % i + os.path.basename(path_image)) + path_mesh_up = os.path.join(tmp_dir, 'up_%s_' % i + os.path.basename(path_image)) + utils.save_volume(maps, aff, h, path_mesh) + cmd = utils.mkcmd(mri_convert, path_mesh, path_mesh_up, '-rl', path_im_upsampled, '-odt float') + os.system(cmd) + path_meshes_up.append(path_mesh_up) + mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) + mesh_up = np.stack([mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1) + shutil.rmtree(tmp_dir) + + floor = np.floor(mesh_up) + ceil = np.ceil(mesh_up) + f_dist = mesh_up - floor + c_dist = ceil - mesh_up + dist = np.minimum(f_dist, c_dist) * utils.add_axis(image_res, axis=[0] * n_dims) + dist = np.sqrt(np.sum(dist ** 2, axis=-1)) + utils.save_volume(dist, aff, h, path_dist_map) + + +def simulate_upsampled_anisotropic_images(image_dir, + downsample_image_result_dir, + resample_image_result_dir, + data_res, + labels_dir=None, + downsample_labels_result_dir=None, + slice_thickness=None, + build_dist_map=False, + path_freesurfer='/usr/local/freesurfer/', + gpu=True, + recompute=True): + """This function takes as input a set of HR images and creates two datasets with it: + 1) a set of LR images obtained by downsampling the HR images with nearest neighbour interpolation, + 2) a set of HR images obtained by resampling the LR images to native HR with linear interpolation. + Additionally, this function can also create a set of LR labels from label maps corresponding to the input images. + :param image_dir: path of directory with input images (only uni-model images supported) + :param downsample_image_result_dir: path of directory where downsampled images will be writen + :param resample_image_result_dir: path of directory where resampled images will be writen + :param data_res: resolution of LR images. Can either be: an int, a float, a list or a numpy array. + :param labels_dir: (optional) path of directory with label maps corresponding to input images + :param downsample_labels_result_dir: (optional) path of directory where downsampled label maps will be writen + :param slice_thickness: (optional) thickness of slices to simulate. Can be a number, a list or a numpy array. + :param build_dist_map: (optional) whether to return the resampled images with an additional channel indicating the + distance of each voxel to the nearest acquired voxel. Default is False. + :param path_freesurfer: (optional) path freesurfer home, as this function uses mri_convert + :param gpu: (optional) whether to use a fast gpu model for blurring + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(resample_image_result_dir) + utils.mkdir(downsample_image_result_dir) + if labels_dir is not None: + assert downsample_labels_result_dir is not None, \ + 'downsample_labels_result_dir should not be None if labels_dir is specified' + utils.mkdir(downsample_labels_result_dir) + + # set up FreeSurfer + os.environ['FREESURFER_HOME'] = path_freesurfer + os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) + mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + + # list images and labels + path_images = utils.list_images_in_folder(image_dir) + path_labels = [None] * len(path_images) if labels_dir is None else utils.list_images_in_folder(labels_dir) + + # initialisation + _, _, n_dims, _, _, image_res = utils.get_volume_info(path_images[0], return_volume=False, aff_ref=np.eye(4)) + data_res = np.squeeze(utils.reformat_to_n_channels_array(data_res, n_dims, n_channels=1)) + slice_thickness = utils.reformat_to_list(slice_thickness, length=n_dims) + + # loop over images + previous_model_input_shape = None + model = None + loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + for idx, (path_image, path_labels) in enumerate(zip(path_images, path_labels)): + loop_info.update(idx) + + # downsample image + path_im_downsampled = os.path.join(downsample_image_result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_im_downsampled)) | recompute: + im, _, aff, n_dims, _, h, image_res = utils.get_volume_info(path_image, return_volume=True) + im, aff_aligned = align_volume_to_ref(im, aff, aff_ref=np.eye(4), return_aff=True, n_dims=n_dims) + im_shape = list(im.shape[:n_dims]) + sigma = blurring_sigma_for_downsampling(image_res, data_res, thickness=slice_thickness) + sigma = [0 if data_res[i] == image_res[i] else sigma[i] for i in range(n_dims)] + + # blur image + if gpu: + if (im_shape != previous_model_input_shape) | (model is None): + previous_model_input_shape = im_shape + image_in = KL.Input(shape=im_shape + [1]) + image = GaussianBlur(sigma=sigma)(image_in) + model = Model(inputs=image_in, outputs=image) + im = np.squeeze(model.predict(utils.add_axis(im, axis=[0, -1]))) + else: + im = blur_volume(im, sigma, mask=None) + utils.save_volume(im, aff_aligned, h, path_im_downsampled) + + # downsample blurred image + voxsize = ' '.join([str(r) for r in data_res]) + cmd = utils.mkcmd(mri_convert, path_im_downsampled, path_im_downsampled, '--voxsize', voxsize, + '-odt float -rt nearest') + os.system(cmd) + + # downsample labels if necessary + if path_labels is not None: + path_lab_downsampled = os.path.join(downsample_labels_result_dir, os.path.basename(path_labels)) + if (not os.path.isfile(path_lab_downsampled)) | recompute: + cmd = utils.mkcmd(mri_convert, path_labels, path_lab_downsampled, '-rl', path_im_downsampled, + '-odt float -rt nearest') + os.system(cmd) + + # upsample image + path_im_upsampled = os.path.join(resample_image_result_dir, os.path.basename(path_image)) + if (not os.path.isfile(path_im_upsampled)) | recompute: + cmd = utils.mkcmd(mri_convert, path_im_downsampled, path_im_upsampled, '-rl', path_image, '-odt float') + os.system(cmd) + + if build_dist_map: + path_dist_map = os.path.join(resample_image_result_dir, 'dist_map_' + os.path.basename(path_image)) + if (not os.path.isfile(path_dist_map)) | recompute: + im, aff, h = utils.load_volume(path_im_downsampled, im_only=False) + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing='ij') + tmp_dir = utils.strip_extension(path_im_downsampled) + '_meshes' + utils.mkdir(tmp_dir) + path_meshes_up = list() + for (i, d_map) in enumerate(dist_map): + path_mesh = os.path.join(tmp_dir, '%s_' % i + os.path.basename(path_image)) + path_mesh_up = os.path.join(tmp_dir, 'up_%s_' % i + os.path.basename(path_image)) + utils.save_volume(d_map, aff, h, path_mesh) + cmd = utils.mkcmd(mri_convert, path_mesh, path_mesh_up, '-rl', path_image, '-odt float') + os.system(cmd) + path_meshes_up.append(path_mesh_up) + mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) + mesh_up = np.stack([mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1) + shutil.rmtree(tmp_dir) + + floor = np.floor(mesh_up) + ceil = np.ceil(mesh_up) + f_dist = mesh_up - floor + c_dist = ceil - mesh_up + dist = np.minimum(f_dist, c_dist) * utils.add_axis(data_res, axis=[0] * n_dims) + dist = np.sqrt(np.sum(dist ** 2, axis=-1)) + utils.save_volume(dist, aff, h, path_dist_map) + + +def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_channels=10, verbose=True): + """Check if all volumes within the same folder share the same characteristics: shape, affine matrix, resolution. + Also have option to check if all volumes have the same intensity values (useful for label maps). + :return four lists, each containing the different values detected for a specific parameter among those to check.""" + + # define information to check + list_shape = list() + list_aff = list() + list_res = list() + list_axes = list() + if check_values: + list_unique_values = list() + else: + list_unique_values = None + + # loop through files + path_images = utils.list_images_in_folder(image_dir) + loop_info = utils.LoopInfo(len(path_images), 10, 'checking', verbose) if verbose else None + for idx, path_image in enumerate(path_images): + if loop_info is not None: + loop_info.update(idx) + + # get info + im, shape, aff, n_dims, _, h, res = utils.get_volume_info(path_image, True, np.eye(4), max_channels) + axes = get_ras_axes(aff, n_dims=n_dims).tolist() + aff[:, np.arange(n_dims)] = aff[:, axes] + aff = (np.int32(np.round(np.array(aff[:3, :3]), 2) * 100) / 100).tolist() + res = (np.int32(np.round(np.array(res), 2) * 100) / 100).tolist() + + # add values to list if not already there + if (shape not in list_shape) | (not keep_unique): + list_shape.append(shape) + if (aff not in list_aff) | (not keep_unique): + list_aff.append(aff) + if (res not in list_res) | (not keep_unique): + list_res.append(res) + if (axes not in list_axes) | (not keep_unique): + list_axes.append(axes) + if list_unique_values is not None: + uni = np.unique(im).tolist() + if (uni not in list_unique_values) | (not keep_unique): + list_unique_values.append(uni) + + return list_shape, list_aff, list_res, list_axes, list_unique_values + + +# ----------------------------------------------- edit label maps in dir ----------------------------------------------- + +def correct_labels_in_dir(labels_dir, results_dir, incorrect_labels, correct_labels=None, + use_nearest_label=False, remove_zero=False, smooth=False, recompute=True): + """This function corrects label values for all label maps in a folder with either + - a list a given values, + - or with the nearest label value. + :param labels_dir: path of directory with input label maps + :param results_dir: path of directory where corrected label maps will be writen + :param incorrect_labels: list of all label values to correct (e.g. [1, 2, 3, 4]). + :param correct_labels: (optional) list of correct label values to replace the incorrect ones. + Correct values must have the same order as their corresponding value in list_incorrect_labels. + When several correct values are possible for the same incorrect value, the nearest correct value will be selected at + each voxel to correct. In that case, the different correct values must be specified inside a list within + list_correct_labels (e.g. [10, 20, 30, [40, 50]). + :param use_nearest_label: (optional) whether to correct the incorrect label values with the nearest labels. + :param remove_zero: (optional) if use_nearest_label is True, set to True not to consider zero among the potential + candidates for the nearest neighbour. + :param smooth: (optional) whether to smooth the corrected label maps + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(results_dir) + + # prepare data files + path_labels = utils.list_images_in_folder(labels_dir) + loop_info = utils.LoopInfo(len(path_labels), 10, 'correcting', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # correct labels + path_result = os.path.join(results_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + im, aff, h = utils.load_volume(path_label, im_only=False, dtype='int32') + im = correct_label_map(im, incorrect_labels, correct_labels, use_nearest_label, remove_zero, smooth) + utils.save_volume(im, aff, h, path_result) + + +def mask_labels_in_dir(labels_dir, result_dir, values_to_keep, masking_value=0, mask_result_dir=None, recompute=True): + """This function masks all label maps in a folder by keeping a set of given label values. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where corrected label maps will be writen + :param values_to_keep: list of values for masking the label maps. + :param masking_value: (optional) value to mask the label maps with + :param mask_result_dir: (optional) path of directory where applied masks will be writen + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + if mask_result_dir is not None: + utils.mkdir(mask_result_dir) + + # reformat values to keep + values_to_keep = utils.reformat_to_list(values_to_keep, load_as_numpy=True) + + # loop over labels + path_labels = utils.list_images_in_folder(labels_dir) + loop_info = utils.LoopInfo(len(path_labels), 10, 'masking', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # mask labels + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if mask_result_dir is not None: + path_result_mask = os.path.join(mask_result_dir, os.path.basename(path_label)) + else: + path_result_mask = '' + if (not os.path.isfile(path_result)) | \ + (mask_result_dir is not None) & (not os.path.isfile(path_result_mask)) | \ + recompute: + lab, aff, h = utils.load_volume(path_label, im_only=False) + if mask_result_dir is not None: + labels, mask = mask_label_map(lab, values_to_keep, masking_value, return_mask=True) + path_result_mask = os.path.join(mask_result_dir, os.path.basename(path_label)) + utils.save_volume(mask, aff, h, path_result_mask) + else: + labels = mask_label_map(lab, values_to_keep, masking_value, return_mask=False) + utils.save_volume(labels, aff, h, path_result) + + +def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, connectivity=1, recompute=True): + """Smooth all label maps in a folder by replacing each voxel by the value of its most numerous neighbours. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where smoothed label maps will be writen + :param gpu: (optional) whether to use a gpu implementation for faster processing + :param labels_list: (optional) if gpu is True, path of numpy array with all label values. + Automatically computed if not provided. + :param connectivity: (optional) connectivity to use when smoothing the label maps + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # create result dir + utils.mkdir(result_dir) + + # list label maps + path_labels = utils.list_images_in_folder(labels_dir) + + if labels_list is not None: + labels_list, _ = utils.get_list_labels(label_list=labels_list, FS_sort=True) + + if gpu: + # initialisation + previous_model_input_shape = None + smoothing_model = None + + # loop over label maps + loop_info = utils.LoopInfo(len(path_labels), 10, 'smoothing', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # smooth label map + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + labels, label_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_label, return_volume=True) + if label_shape != previous_model_input_shape: + previous_model_input_shape = label_shape + smoothing_model = smoothing_gpu_model(label_shape, labels_list, connectivity) + unique_labels = np.unique(labels).astype('int32') + if labels_list is None: + smoothed_labels = smoothing_model.predict(utils.add_axis(labels)) + else: + labels_to_keep = [lab for lab in unique_labels if lab not in labels_list] + new_labels, mask_new_labels = mask_label_map(labels, labels_to_keep, return_mask=True) + smoothed_labels = np.squeeze(smoothing_model.predict(utils.add_axis(labels))) + smoothed_labels = np.where(mask_new_labels, new_labels, smoothed_labels) + mask_new_zeros = (labels > 0) & (smoothed_labels == 0) + smoothed_labels[mask_new_zeros] = labels[mask_new_zeros] + utils.save_volume(smoothed_labels, aff, h, path_result, dtype='int32') + + else: + # build kernel + _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) + kernel = utils.build_binary_structure(connectivity, n_dims, shape=n_dims) + + # loop over label maps + loop_info = utils.LoopInfo(len(path_labels), 10, 'smoothing', True) + for idx, path in enumerate(path_labels): + loop_info.update(idx) + + # smooth label map + path_result = os.path.join(result_dir, os.path.basename(path)) + if (not os.path.isfile(path_result)) | recompute: + volume, aff, h = utils.load_volume(path, im_only=False) + new_volume = smooth_label_map(volume, kernel, labels_list) + utils.save_volume(new_volume, aff, h, path_result, dtype='int32') + + +def smoothing_gpu_model(label_shape, label_list, connectivity=1): + """This function builds a gpu model in keras with a tensorflow backend to smooth label maps. + This model replaces each voxel of the input by the value of its most numerous neighbour. + :param label_shape: shape of the label map + :param label_list: list of all labels to consider + :param connectivity: (optional) connectivity to use when smoothing the label maps + :return: gpu smoothing model + """ + + # convert labels so values are in [0, ..., N-1] and use one hot encoding + n_labels = label_list.shape[0] + labels_in = KL.Input(shape=label_shape, name='lab_input', dtype='int32') + labels = ConvertLabels(label_list)(labels_in) + labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels) + + # count neighbouring voxels + n_dims, _ = utils.get_dims(label_shape) + k = utils.add_axis(utils.build_binary_structure(connectivity, n_dims, shape=n_dims), axis=[-1, -1]) + kernel = KL.Lambda(lambda x: tf.convert_to_tensor(k, dtype='float32'))([]) + split = KL.Lambda(lambda x: tf.split(x, [1] * n_labels, axis=-1))(labels) + labels = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding='SAME'))([split[0], kernel]) + for i in range(1, n_labels): + tmp = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding='SAME'))([split[i], kernel]) + labels = KL.Lambda(lambda x: tf.concat([x[0], x[1]], -1))([labels, tmp]) + + # take the argmax and convert labels to original values + labels = KL.Lambda(lambda x: tf.math.argmax(x, -1))(labels) + labels = ConvertLabels(np.arange(n_labels), label_list)(labels) + return Model(inputs=labels_in, outputs=labels) + + +def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors=1., gpu=False, recompute=True): + """Erode a given set of label values for all label maps in a folder. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where cropped label maps will be writen + :param labels_to_erode: list of label values to erode + :param erosion_factors: (optional) list of erosion factors to use for each label value. If values are integers, + normal erosion applies. If float, we first 1) blur a mask of the corresponding label value with a gpu model, + and 2) use the erosion factor as a threshold in the blurred mask. + If erosion_factors is a single value, the same factor will be applied to all labels. + :param gpu: (optional) whether to use a fast gpu model for blurring (if erosion factors are floats) + :param recompute: (optional) whether to recompute result files even if they already exists + """ + # create result dir + utils.mkdir(result_dir) + + # loop over label maps + model = None + path_labels = utils.list_images_in_folder(labels_dir) + loop_info = utils.LoopInfo(len(path_labels), 5, 'eroding', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # erode label map + labels, aff, h = utils.load_volume(path_label, im_only=False) + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + labels, model = erode_label_map(labels, labels_to_erode, erosion_factors, gpu, model, return_model=True) + utils.save_volume(labels, aff, h, path_result) + + +def upsample_labels_in_dir(labels_dir, + target_res, + result_dir, + path_label_list=None, + path_freesurfer='/usr/local/freesurfer/', + recompute=True): + """This function upsamples all label maps within a folder. Importantly, each label map is converted into probability + maps for all label values, and all these maps are upsampled separately. The upsampled label maps are recovered by + taking the argmax of the label values probability maps. + :param labels_dir: path of directory with label maps to upsample + :param target_res: resolution at which to upsample the label maps. can be a single number (isotropic), or a list. + :param result_dir: path of directory where the upsampled label maps will be writen + :param path_label_list: (optional) path of numpy array containing all label values. + Computed automatically if not given. + :param path_freesurfer: (optional) path freesurfer home (upsampling performed with mri_convert) + :param recompute: (optional) whether to recompute result files even if they already exists + """ + + # prepare result dir + utils.mkdir(result_dir) + + # set up FreeSurfer + os.environ['FREESURFER_HOME'] = path_freesurfer + os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) + mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + + # list label maps + path_labels = utils.list_images_in_folder(labels_dir) + labels_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_labels[0], max_channels=3) + + # build command + target_res = utils.reformat_to_list(target_res, length=n_dims) + post_cmd = '-voxsize ' + ' '.join([str(r) for r in target_res]) + ' -odt float' + + # load label list and corresponding LUT to make sure that labels go from 0 to N-1 + label_list, _ = utils.get_list_labels(path_label_list, labels_dir=path_labels, FS_sort=False) + new_label_list = np.arange(len(label_list), dtype='int32') + lut = utils.get_mapping_lut(label_list) + + # loop over label maps + loop_info = utils.LoopInfo(len(path_labels), 5, 'upsampling', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + path_result = os.path.join(result_dir, os.path.basename(path_label)) + if (not os.path.isfile(path_result)) | recompute: + + # load volume + labels, aff, h = utils.load_volume(path_label, im_only=False) + labels = lut[labels.astype('int')] + + # create individual folders for label map + basefilename = utils.strip_extension(os.path.basename(path_label)) + indiv_label_dir = os.path.join(result_dir, basefilename) + upsample_indiv_label_dir = os.path.join(result_dir, basefilename + '_upsampled') + utils.mkdir(indiv_label_dir) + utils.mkdir(upsample_indiv_label_dir) + + # loop over label values + for label in new_label_list: + path_mask = os.path.join(indiv_label_dir, str(label) + '.nii.gz') + path_mask_upsampled = os.path.join(upsample_indiv_label_dir, str(label) + '.nii.gz') + if not os.path.isfile(path_mask): + mask = (labels == label) * 1.0 + utils.save_volume(mask, aff, h, path_mask) + if not os.path.isfile(path_mask_upsampled): + cmd = utils.mkcmd(mri_convert, path_mask, path_mask_upsampled, post_cmd) + os.system(cmd) + + # compute argmax of upsampled probability maps (upload them one at a time) + probmax, aff, h = utils.load_volume(os.path.join(upsample_indiv_label_dir, '0.nii.gz'), im_only=False) + labels = np.zeros(probmax.shape, dtype='int') + for label in new_label_list: + prob = utils.load_volume(os.path.join(upsample_indiv_label_dir, str(label) + '.nii.gz')) + idx = prob > probmax + labels[idx] = label + probmax[idx] = prob[idx] + utils.save_volume(label_list[labels], aff, h, path_result, dtype='int32') + + +def compute_hard_volumes_in_dir(labels_dir, + voxel_volume=None, + path_label_list=None, + skip_background=True, + path_numpy_result=None, + path_csv_result=None, + FS_sort=False): + """Compute hard volumes of structures for all label maps in a folder. + :param labels_dir: path of directory with input label maps + :param voxel_volume: (optional) volume of the voxels. If None, it will be directly inferred from the file header. + Set to 1 for a voxel count. + :param path_label_list: (optional) list of labels to compute volumes for. + Can be an int, a sequence, or a numpy array. If None, the volumes of all label values are computed for each subject. + :param skip_background: (optional) whether to skip computing the volume of the background. + If label_list is None, this assumes background value is 0. + If label_list is not None, this assumes the background is the first value in label list. + :param path_numpy_result: (optional) path where to write the result volumes as a numpy array. + :param path_csv_result: (optional) path where to write the results as csv file. + :param FS_sort: (optional) whether to sort the labels in FreeSurfer order. + :return: numpy array with the volume of each structure for all subjects. + Rows represent label values, and columns represent subjects. + """ + + # create result directories + if path_numpy_result is not None: + utils.mkdir(os.path.dirname(path_numpy_result)) + if path_csv_result is not None: + utils.mkdir(os.path.dirname(path_csv_result)) + + # load or compute labels list + label_list, _ = utils.get_list_labels(path_label_list, labels_dir, FS_sort=FS_sort) + + # create csv volume file if necessary + if path_csv_result is not None: + if skip_background: + cvs_header = [['subject'] + [str(lab) for lab in label_list[1:]]] + else: + cvs_header = [['subject'] + [str(lab) for lab in label_list]] + with open(path_csv_result, 'w') as csvFile: + writer = csv.writer(csvFile) + writer.writerows(cvs_header) + csvFile.close() + + # loop over label maps + path_labels = utils.list_images_in_folder(labels_dir) + if skip_background: + volumes = np.zeros((label_list.shape[0] - 1, len(path_labels))) + else: + volumes = np.zeros((label_list.shape[0], len(path_labels))) + loop_info = utils.LoopInfo(len(path_labels), 10, 'processing', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # load segmentation, and compute unique labels + labels, _, _, _, _, _, subject_res = utils.get_volume_info(path_label, return_volume=True) + if voxel_volume is None: + voxel_volume = float(np.prod(subject_res)) + subject_volumes = compute_hard_volumes(labels, voxel_volume, label_list, skip_background) + volumes[:, idx] = subject_volumes + + # write volumes + if path_csv_result is not None: + subject_volumes = np.around(volumes[:, idx], 3) + row = [utils.strip_suffix(os.path.basename(path_label))] + [str(vol) for vol in subject_volumes] + with open(path_csv_result, 'a') as csvFile: + writer = csv.writer(csvFile) + writer.writerow(row) + csvFile.close() + + # write numpy array if necessary + if path_numpy_result is not None: + np.save(path_numpy_result, volumes) + + return volumes + + +def build_atlas(labels_dir, + label_list, + align_centre_of_mass=False, + margin=15, + shape=None, + path_atlas=None): + """This function builds a binary atlas (defined by label values > 0) from several label maps. + :param labels_dir: path of directory with input label maps + :param label_list: list of all labels in the label maps. If there is more than 1 value here, the different channels + of the atlas (each corresponding to the probability map of a given label) will in the same order as in this list. + :param align_centre_of_mass: whether to build the atlas by aligning the center of mass of each label map. + If False, the atlas has the same size as the input label maps, which are assumed to be aligned. + :param margin: (optional) If align_centre_of_mass is True, margin by which to crop the input label maps around + their center of mass. Therefore it controls the size of the output atlas: (2*margin + 1)**n_dims. + :param shape: shape of the output atlas. + :param path_atlas: (optional) path where the output atlas will be writen. + Default is None, where the atlas is not saved.""" + + # list of all label maps and create result dir + path_labels = utils.list_images_in_folder(labels_dir) + n_label_maps = len(path_labels) + utils.mkdir(os.path.dirname(path_atlas)) + + # read list labels and create lut + label_list = np.array(utils.reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + lut = utils.get_mapping_lut(label_list) + n_labels = len(label_list) + + # create empty atlas + im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_labels[0], aff_ref=np.eye(4)) + if align_centre_of_mass: + shape = [margin * 2] * n_dims + else: + shape = utils.reformat_to_list(shape, length=n_dims) if shape is not None else im_shape + shape = shape + [n_labels] if n_labels > 1 else shape + atlas = np.zeros(shape) + + # loop over label maps + loop_info = utils.LoopInfo(n_label_maps, 10, 'processing', True) + for idx, path_label in enumerate(path_labels): + loop_info.update(idx) + + # load label map and build mask + lab = utils.load_volume(path_label, dtype='int32', aff_ref=np.eye(4)) + lab = correct_label_map(lab, [31, 63, 72], [4, 43, 0]) + lab = lut[lab.astype('int')] + lab = pad_volume(lab, shape[:n_dims]) + lab = crop_volume(lab, cropping_shape=shape[:n_dims]) + indices = np.where(lab > 0) + + if len(label_list) > 1: + lab = np.identity(n_labels)[lab] + + # crop label map around centre of mass + if align_centre_of_mass: + centre_of_mass = np.array([np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])], dtype='int32') + min_crop = centre_of_mass - margin + max_crop = centre_of_mass + margin + atlas += lab[min_crop[0]:max_crop[0], min_crop[1]:max_crop[1], min_crop[2]:max_crop[2], ...] + # otherwise just add the one-hot labels + else: + atlas += lab + + # normalise atlas and save it if necessary + atlas /= n_label_maps + atlas = align_volume_to_ref(atlas, np.eye(4), aff_ref=aff, n_dims=n_dims) + if path_atlas is not None: + utils.save_volume(atlas, aff, h, path_atlas) + + return atlas + + +# ---------------------------------------------------- edit dataset ---------------------------------------------------- + +def check_images_and_labels(image_dir, labels_dir, verbose=True): + """Check if corresponding images and labels have the same affine matrices and shapes. + Labels are matched to images by sorting order. + :param image_dir: path of directory with input images + :param labels_dir: path of directory with corresponding label maps + :param verbose: whether to print out info + """ + + # list images and labels + path_images = utils.list_images_in_folder(image_dir) + path_labels = utils.list_images_in_folder(labels_dir) + assert len(path_images) == len(path_labels), 'different number of files in image_dir and labels_dir' + + # loop over images and labels + loop_info = utils.LoopInfo(len(path_images), 10, 'checking', verbose) if verbose else None + for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): + if loop_info is not None: + loop_info.update(idx) + + # load images and labels + im, aff_im, h_im = utils.load_volume(path_image, im_only=False) + lab, aff_lab, h_lab = utils.load_volume(path_label, im_only=False) + aff_im_list = np.round(aff_im, 2).tolist() + aff_lab_list = np.round(aff_lab, 2).tolist() + + # check matching affine and shape + if aff_lab_list != aff_im_list: + print('aff mismatch :\n' + path_image) + print(aff_im_list) + print(path_label) + print(aff_lab_list) + print('') + if lab.shape != im.shape: + print('shape mismatch :\n' + path_image) + print(im.shape) + print('\n' + path_label) + print(lab.shape) + print('') + + +def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_result_dir=None, margin=5): + """Crop all label maps in a directory to the minimum possible common size, with a margin. + This is achieved by cropping each label map individually to the minimum size, and by padding all the cropped maps to + the same size (taken to be the maximum size of the cropped maps). + If images are provided, they undergo the same transformations as their corresponding label maps. + :param labels_dir: path of directory with input label maps + :param result_dir: path of directory where cropped label maps will be writen + :param image_dir: (optional) if not None, the cropping will be applied to all images in this directory + :param image_result_dir: (optional) path of directory where cropped images will be writen + :param margin: (optional) margin to apply around the label maps during cropping + """ + + # create result dir + utils.mkdir(result_dir) + if image_dir is not None: + assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + utils.mkdir(image_result_dir) + + # list labels and images + path_labels = utils.list_images_in_folder(labels_dir) + if image_dir is not None: + path_images = utils.list_images_in_folder(image_dir) + else: + path_images = [None] * len(path_labels) + _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) + + # loop over label maps for cropping + print('\ncropping labels to individual minimum size') + maximum_size = np.zeros(n_dims) + loop_info = utils.LoopInfo(len(path_labels), 10, 'cropping', True) + for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): + loop_info.update(idx) + + # crop label maps and update maximum size of cropped map + label, aff, h = utils.load_volume(path_label, im_only=False) + label, cropping, aff = crop_volume_around_region(label, aff=aff) + utils.save_volume(label, aff, h, os.path.join(result_dir, os.path.basename(path_label))) + maximum_size = np.maximum(maximum_size, np.array(label.shape) + margin * 2) # *2 to add margin on each side + + # crop images if required + if path_image is not None: + image, aff_im, h_im = utils.load_volume(path_image, im_only=False) + image, aff_im = crop_volume_with_idx(image, cropping, aff=aff_im) + utils.save_volume(image, aff_im, h_im, os.path.join(image_result_dir, os.path.basename(path_image))) + + # loop over label maps for padding + print('\npadding labels to same size') + loop_info = utils.LoopInfo(len(path_labels), 10, 'padding', True) + for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): + loop_info.update(idx) + + # pad label maps to maximum size + path_result = os.path.join(result_dir, os.path.basename(path_label)) + label, aff, h = utils.load_volume(path_result, im_only=False) + label, aff = pad_volume(label, maximum_size, aff=aff) + utils.save_volume(label, aff, h, path_result) + + # crop images if required + if path_image is not None: + path_result = os.path.join(image_result_dir, os.path.basename(path_image)) + image, aff, h = utils.load_volume(path_result, im_only=False) + image, aff = pad_volume(image, maximum_size, aff=aff) + utils.save_volume(image, aff, h, path_result) + + +def crop_dataset_around_region_of_same_size(labels_dir, + result_dir, + image_dir=None, + image_result_dir=None, + margin=0, + recompute=True): + + # create result dir + utils.mkdir(result_dir) + if image_dir is not None: + assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + utils.mkdir(image_result_dir) + + # list labels and images + path_labels = utils.list_images_in_folder(labels_dir) + path_images = utils.list_images_in_folder(image_dir) if image_dir is not None else [None] * len(path_labels) + _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) + + recompute_labels = any([not os.path.isfile(os.path.join(result_dir, os.path.basename(path))) + for path in path_labels]) + if (image_dir is not None) & (not recompute_labels): + recompute_labels = any([not os.path.isfile(os.path.join(image_result_dir, os.path.basename(path))) + for path in path_images]) + + # get minimum patch shape so that no labels are left out when doing the cropping later on + max_crop_shape = np.zeros(n_dims) + if recompute_labels: + for path_label in path_labels: + label, aff, _ = utils.load_volume(path_label, im_only=False) + label = align_volume_to_ref(label, aff, aff_ref=np.eye(4)) + label = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + _, cropping = crop_volume_around_region(label) + max_crop_shape = np.maximum(cropping[n_dims:] - cropping[:n_dims], max_crop_shape) + max_crop_shape += np.array(utils.reformat_to_list(margin, length=n_dims, dtype='int')) + print('max_crop_shape: ', max_crop_shape) + + # crop shapes (possibly with padding if images are smaller than crop shape) + for path_label, path_image in zip(path_labels, path_images): + + path_label_result = os.path.join(result_dir, os.path.basename(path_label)) + path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) + + if (not os.path.isfile(path_image_result)) | (not os.path.isfile(path_label_result)) | recompute: + # load labels + label, aff, h_la = utils.load_volume(path_label, im_only=False, dtype='int32') + label, aff_new = align_volume_to_ref(label, aff, aff_ref=np.eye(4), return_aff=True) + vol_shape = np.array(label.shape[:n_dims]) + if path_image is not None: + image, _, h_im = utils.load_volume(path_image, im_only=False) + image = align_volume_to_ref(image, aff, aff_ref=np.eye(4)) + else: + image = h_im = None + + # mask labels + mask = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + label[np.logical_not(mask)] = 0 + + # find cropping indices + indices = np.nonzero(mask) + min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) + max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) + + # expand/retract (depending on the desired shape) the cropping region around the centre + intermediate_vol_shape = max_idx - min_idx + min_idx = min_idx - np.int32(np.ceil((max_crop_shape - intermediate_vol_shape) / 2)) + max_idx = max_idx + np.int32(np.floor((max_crop_shape - intermediate_vol_shape) / 2)) + + # check if we need to pad the output to the desired shape + min_padding = np.abs(np.minimum(min_idx, 0)) + max_padding = np.maximum(max_idx - vol_shape, 0) + if np.any(min_padding > 0) | np.any(max_padding > 0): + pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + else: + pad_margins = None + cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + + # crop volume + label = crop_volume_with_idx(label, cropping, n_dims=n_dims) + if path_image is not None: + image = crop_volume_with_idx(image, cropping, n_dims=n_dims) + + # pad volume if necessary + if pad_margins is not None: + label = np.pad(label, pad_margins, mode='constant', constant_values=0) + if path_image is not None: + _, n_channels = utils.get_dims(image.shape) + pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + image = np.pad(image, pad_margins, mode='constant', constant_values=0) + + # update aff + if n_dims == 2: + min_idx = np.append(min_idx, 0) + aff_new[0:3, -1] = aff_new[0:3, -1] + aff_new[:3, :3] @ min_idx + + # write labels + label, aff_final = align_volume_to_ref(label, aff_new, aff_ref=aff, return_aff=True) + utils.save_volume(label, aff_final, h_la, path_label_result, dtype='int32') + if path_image is not None: + image = align_volume_to_ref(image, aff_new, aff_ref=aff) + utils.save_volume(image, aff_final, h_im, path_image_result) + + +def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_result_dir, margin=0, + cropping_shape_div_by=None, recompute=True): + + # create result dir + utils.mkdir(image_result_dir) + utils.mkdir(labels_result_dir) + + # list volumes and masks + path_images = utils.list_images_in_folder(image_dir) + path_labels = utils.list_images_in_folder(labels_dir) + _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_labels[0]) + + # loop over images and labels + loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): + loop_info.update(idx) + + path_label_result = os.path.join(labels_result_dir, os.path.basename(path_label)) + path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) + + if (not os.path.isfile(path_label_result)) | (not os.path.isfile(path_image_result)) | recompute: + + image, aff, h_im = utils.load_volume(path_image, im_only=False) + label, _, h_lab = utils.load_volume(path_label, im_only=False) + mask = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + label[np.logical_not(mask)] = 0 + vol_shape = np.array(label.shape[:n_dims]) + + # find cropping indices + indices = np.nonzero(mask) + min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) + max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) + + # expand/retract (depending on the desired shape) the cropping region around the centre + intermediate_vol_shape = max_idx - min_idx + cropping_shape = np.array([utils.find_closest_number_divisible_by_m(s, cropping_shape_div_by, + answer_type='higher') + for s in intermediate_vol_shape]) + min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) + max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + + # check if we need to pad the output to the desired shape + min_padding = np.abs(np.minimum(min_idx, 0)) + max_padding = np.maximum(max_idx - vol_shape, 0) + if np.any(min_padding > 0) | np.any(max_padding > 0): + pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + else: + pad_margins = None + cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + + # crop volume + label = crop_volume_with_idx(label, cropping, n_dims=n_dims) + image = crop_volume_with_idx(image, cropping, n_dims=n_dims) + + # pad volume if necessary + if pad_margins is not None: + label = np.pad(label, pad_margins, mode='constant', constant_values=0) + pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + image = np.pad(image, pad_margins, mode='constant', constant_values=0) + + # update aff + if n_dims == 2: + min_idx = np.append(min_idx, 0) + aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ min_idx + + # write results + utils.save_volume(image, aff, h_im, path_image_result) + utils.save_volume(label, aff, h_lab, path_label_result, dtype='int32') + + +def subdivide_dataset_to_patches(patch_shape, + image_dir=None, + image_result_dir=None, + labels_dir=None, + labels_result_dir=None, + full_background=True, + remove_after_dividing=False): + """This function subdivides images and/or label maps into several smaller patches of specified shape. + :param patch_shape: shape of patches to create. Can either be an int, a sequence, or a 1d numpy array. + :param image_dir: (optional) path of directory with input images + :param image_result_dir: (optional) path of directory where image patches will be writen + :param labels_dir: (optional) path of directory with input label maps + :param labels_result_dir: (optional) path of directory where label map patches will be writen + :param full_background: (optional) whether to keep patches only labelled as background (only if label maps are + provided). + :param remove_after_dividing: (optional) whether to delete input images after having divided them in smaller + patches. This enables to save disk space in the subdivision process. + """ + + # create result dir and list images and label maps + assert (image_dir is not None) | (labels_dir is not None), \ + 'at least one of image_dir or labels_dir should not be None.' + if image_dir is not None: + assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + utils.mkdir(image_result_dir) + path_images = utils.list_images_in_folder(image_dir) + else: + path_images = None + if labels_dir is not None: + assert labels_result_dir is not None, 'labels_result_dir should not be None if labels_dir is specified' + utils.mkdir(labels_result_dir) + path_labels = utils.list_images_in_folder(labels_dir) + else: + path_labels = None + if path_images is None: + path_images = [None] * len(path_labels) + if path_labels is None: + path_labels = [None] * len(path_images) + + # reformat path_shape + patch_shape = utils.reformat_to_list(patch_shape) + n_dims, _ = utils.get_dims(patch_shape) + + # loop over images and labels + loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): + loop_info.update(idx) + + # load image and labels + if path_image is not None: + im, aff_im, h_im = utils.load_volume(path_image, im_only=False, squeeze=False) + else: + im = aff_im = h_im = None + if path_label is not None: + lab, aff_lab, h_lab = utils.load_volume(path_label, im_only=False, squeeze=True) + else: + lab = aff_lab = h_lab = None + + # get volume shape + if path_image is not None: + shape = im.shape + else: + shape = lab.shape + + # crop image and label map to size divisible by patch_shape + new_size = np.array([utils.find_closest_number_divisible_by_m(shape[i], patch_shape[i]) for i in range(n_dims)]) + crop = np.round((np.array(shape[:n_dims]) - new_size) / 2).astype('int') + crop = np.concatenate((crop, crop + new_size), axis=0) + if (im is not None) & (n_dims == 2): + im = im[crop[0]:crop[2], crop[1]:crop[3], ...] + elif (im is not None) & (n_dims == 3): + im = im[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] + if (lab is not None) & (n_dims == 2): + lab = lab[crop[0]:crop[2], crop[1]:crop[3], ...] + elif (lab is not None) & (n_dims == 3): + lab = lab[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] + + # loop over patches + n_im = 0 + n_crop = (new_size / patch_shape).astype('int') + for i in range(n_crop[0]): + i *= patch_shape[0] + for j in range(n_crop[1]): + j *= patch_shape[1] + + if n_dims == 2: + + # crop volumes + if lab is not None: + temp_la = lab[i:i + patch_shape[0], j:j + patch_shape[1], ...] + else: + temp_la = None + if im is not None: + temp_im = im[i:i + patch_shape[0], j:j + patch_shape[1], ...] + else: + temp_im = None + + # write patches + if temp_la is not None: + if full_background | (not (temp_la == 0).all()): + n_im += 1 + utils.save_volume(temp_la, aff_lab, h_lab, os.path.join(labels_result_dir, + os.path.basename(path_label.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + if temp_im is not None: + utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, + os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + else: + utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, + os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + + elif n_dims == 3: + for k in range(n_crop[2]): + k *= patch_shape[2] + + # crop volumes + if lab is not None: + temp_la = lab[i:i + patch_shape[0], j:j + patch_shape[1], k:k + patch_shape[2], ...] + else: + temp_la = None + if im is not None: + temp_im = im[i:i + patch_shape[0], j:j + patch_shape[1], k:k + patch_shape[2], ...] + else: + temp_im = None + + # write patches + if temp_la is not None: + if full_background | (not (temp_la == 0).all()): + n_im += 1 + utils.save_volume(temp_la, aff_lab, h_lab, os.path.join(labels_result_dir, + os.path.basename(path_label.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + if temp_im is not None: + utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, + os.path.basename(path_image.replace('.nii.gz', + '_%d.nii.gz' % n_im)))) + else: + utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, + os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + + if remove_after_dividing: + if path_image is not None: + os.remove(path_image) + if path_label is not None: + os.remove(path_label) diff --git a/nobrainer/ext/lab2im/image_generator.py b/nobrainer/ext/lab2im/image_generator.py new file mode 100644 index 00000000..073442e8 --- /dev/null +++ b/nobrainer/ext/lab2im/image_generator.py @@ -0,0 +1,266 @@ +""" +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np +import numpy.random as npr + +# project imports +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import edit_volumes +from nobrainer.ext.lab2im.lab2im_model import lab2im_model + + +class ImageGenerator: + + def __init__(self, + labels_dir, + generation_labels=None, + output_labels=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + generation_classes=None, + prior_distributions='uniform', + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + blur_range=1.15): + """ + This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels + maps, and a python generator that supplies the input data for this model. + To generate pairs of image/labels you can just call the method generate_image() on an object of this class. + + :param labels_dir: path of folder with all input label maps, or to a single label map. + + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + + # label maps-related parameters + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array. + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in + the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps + will be converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + + # output-related parameters + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthetised. Default is 1. + :param target_res: (optional) target resolution of the generated images and corresponding label maps. + If None, the outputs will have the same resolution as the input label maps. + Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image. + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites + output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or + the path to a 1d numpy array. + + # GMM-sampling parameters + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a + 1d numpy array, or the path to a 1d numpy array. + It should have the same length as generation_labels, and contain values between 0 and K-1, where K is the total + number of classes. Default is all labels have different classes (K=len(generation_labels)). + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each + mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + + # blurring parameters + :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is + given or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a c + coefficient sampled from a uniform distribution with bounds [1/blur_range, blur_range]. + If None, no randomisation. Default is 1.15. + """ + + # prepare data files + self.labels_paths = utils.list_images_in_folder(labels_dir) + + # generation parameters + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + self.n_channels = n_channels + if generation_labels is not None: + self.generation_labels = utils.load_array_if_path(generation_labels) + else: + self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir) + if output_labels is not None: + self.output_labels = utils.load_array_if_path(output_labels) + else: + self.output_labels = self.generation_labels + self.target_res = utils.load_array_if_path(target_res) + self.batchsize = batchsize + # preliminary operations + self.output_shape = utils.load_array_if_path(output_shape) + self.output_div_by_n = output_div_by_n + # GMM parameters + self.prior_distributions = prior_distributions + if generation_classes is not None: + self.generation_classes = utils.load_array_if_path(generation_classes) + assert self.generation_classes.shape == self.generation_labels.shape, \ + 'if provided, generation labels should have the same shape as generation_labels' + unique_classes = np.unique(self.generation_classes) + assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ + 'generation_classes should a linear range between 0 and its maximum value.' + else: + self.generation_classes = np.arange(self.generation_labels.shape[0]) + self.prior_means = utils.load_array_if_path(prior_means) + self.prior_stds = utils.load_array_if_path(prior_stds) + self.use_specific_stats_for_channel = use_specific_stats_for_channel + + # blurring parameters + self.blur_range = blur_range + + # build transformation model + self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model() + + # build generator for model inputs + self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels)) + + # build brain generator + self.image_generator = self._build_image_generator() + + def _build_lab2im_model(self): + # build_model + lab_to_im_model = lab2im_model(labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + blur_range=self.blur_range) + out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] + return lab_to_im_model, out_shape + + def _build_image_generator(self): + while True: + model_inputs = next(self.model_inputs_generator) + [image, labels] = self.labels_to_image_model.predict(model_inputs) + yield image, labels + + def generate_image(self): + """call this method when an object of this class has been instantiated to generate new brains""" + (image, labels) = next(self.image_generator) + # put back images in native space + list_images = list() + list_labels = list() + for i in range(self.batchsize): + list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff, + n_dims=self.n_dims)) + list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff, + n_dims=self.n_dims)) + image = np.stack(list_images, axis=0) + labels = np.stack(list_labels, axis=0) + return np.squeeze(image), np.squeeze(labels) + + def _build_model_inputs(self, n_labels): + + # get label info + _, _, n_dims, _, _, _ = utils.get_volume_info(self.labels_paths[0]) + + # Generate! + while True: + + # randomly pick as many images as batchsize + indices = npr.randint(len(self.labels_paths), size=self.batchsize) + + # initialise input lists + list_label_maps = [] + list_means = [] + list_stds = [] + + for idx in indices: + + # load label in identity space, and add them to inputs + y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4)) + list_label_maps.append(utils.add_axis(y, axis=[0, -1])) + + # add means and standard deviations to inputs + means = np.empty((1, n_labels, 0)) + stds = np.empty((1, n_labels, 0)) + for channel in range(self.n_channels): + + # retrieve channel specific stats if necessary + if isinstance(self.prior_means, np.ndarray): + if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel: + if self.prior_means.shape[0] / 2 != self.n_channels: + raise ValueError("the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True.") + tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :] + else: + tmp_prior_means = self.prior_means + else: + tmp_prior_means = self.prior_means + if isinstance(self.prior_stds, np.ndarray): + if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel: + if self.prior_stds.shape[0] / 2 != self.n_channels: + raise ValueError("the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True.") + tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :] + else: + tmp_prior_stds = self.prior_stds + else: + tmp_prior_stds = self.prior_stds + + # draw means and std devs from priors + tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, + self.prior_distributions, 125., 100., + positive_only=True) + tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, + self.prior_distributions, 15., 10., + positive_only=True) + tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1]) + tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1]) + means = np.concatenate([means, tmp_means], axis=-1) + stds = np.concatenate([stds, tmp_stds], axis=-1) + list_means.append(means) + list_stds.append(stds) + + # build list of inputs of augmentation model + list_inputs = [list_label_maps, list_means, list_stds] + if self.batchsize > 1: # concatenate individual input types if batchsize > 1 + list_inputs = [np.concatenate(item, 0) for item in list_inputs] + else: + list_inputs = [item[0] for item in list_inputs] + + yield list_inputs diff --git a/nobrainer/ext/lab2im/lab2im_model.py b/nobrainer/ext/lab2im/lab2im_model.py new file mode 100644 index 00000000..f32c5b5a --- /dev/null +++ b/nobrainer/ext/lab2im/lab2im_model.py @@ -0,0 +1,174 @@ +""" +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np +import keras.layers as KL +from keras.models import Model + +# project imports +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import layers +from nobrainer.ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling + + +def lab2im_model(labels_shape, + n_channels, + generation_labels, + output_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + blur_range=1.15): + """ + This function builds a keras/tensorflow model to generate images from provided label maps. + The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. + The model will take as inputs: + -a label map + -a vector containing the means of the Gaussian Mixture Model for each label, + -a vector containing the standard deviations of the Gaussian Mixture Model for each label, + -an array of size batch*(n_dims+1)*(n_dims+1) representing a linear transformation + The model returns: + -the generated image normalised between 0 and 1. + -the corresponding label map, with only the labels present in output_labels (the other are reset to zero). + :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array. + :param n_channels: number of channels to be synthetised. + :param generation_labels: list of all possible label values in the input label maps. + Can be a sequence or a 1d numpy array. + :param output_labels: list of the same length as generation_labels to indicate which values to use in the label maps + returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be converted to + output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] if you wish to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param atlas_res: resolution of the input label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param target_res: target resolution of the generated images and corresponding label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param output_shape: (optional) desired shape of the output images. + If the atlas and target resolutions are the same, the output will be cropped to output_shape, and if the two + resolutions are different, the output will be resized with trilinear interpolation to output_shape. + Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape + if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param blur_range: (optional) Randomise the standard deviation of the blurring kernels, (whether data_res is given + or not). At each mini_batch, the standard deviation of the blurring kernels are multiplied by a coefficient sampled + from a uniform distribution with bounds [1/blur_range, blur_range]. If None, no randomisation. Default is 1.15. + """ + + # reformat resolutions + labels_shape = utils.reformat_to_list(labels_shape) + n_dims, _ = utils.get_dims(labels_shape) + atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0] + target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + + # get shapes + crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + + # define model inputs + labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32') + means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') + stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input') + + # deform labels + labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input) + + # cropping + if crop_shape != labels_shape: + labels._keras_shape = tuple(labels.get_shape().as_list()) + labels = layers.RandomCrop(crop_shape)(labels) + + # build synthetic image + labels._keras_shape = tuple(labels.get_shape().as_list()) + image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + + # apply bias field + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image) + + # intensity augmentation + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image) + + # blur image + sigma = blurring_sigma_for_downsampling(atlas_res, target_res) + image._keras_shape = tuple(image.get_shape().as_list()) + image = layers.GaussianBlur(sigma=sigma, random_blur_range=blur_range)(image) + + # resample to target res + if crop_shape != output_shape: + image = resample_tensor(image, output_shape, interp_method='linear') + labels = resample_tensor(labels, output_shape, interp_method='nearest') + + # reset unwanted labels to zero + labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + + # build model (dummy layer enables to keep the labels when plugging this model to other models) + image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) + brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels]) + + return brain_model + + +def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n): + + n_dims = len(atlas_res) + + # get resampling factor + if atlas_res.tolist() != target_res.tolist(): + resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)] + else: + resample_factor = None + + # output shape specified, need to get cropping shape, and resample shape if necessary + if output_shape is not None: + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + + # make sure that output shape is smaller or equal to label shape + if resample_factor is not None: + output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + else: + output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape] + if output_shape != tmp_shape: + print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, + tmp_shape)) + output_shape = tmp_shape + + # get cropping and resample shape + if resample_factor is not None: + cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + else: + cropping_shape = output_shape + + # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n + else: + cropping_shape = labels_shape + if resample_factor is not None: + output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)] + else: + output_shape = cropping_shape + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer') + for s in output_shape] + + return cropping_shape, output_shape diff --git a/nobrainer/ext/lab2im/layers.py b/nobrainer/ext/lab2im/layers.py new file mode 100644 index 00000000..e477d607 --- /dev/null +++ b/nobrainer/ext/lab2im/layers.py @@ -0,0 +1,2060 @@ +""" +This file regroups several custom keras layers used in the generation model: + - RandomSpatialDeformation, + - RandomCrop, + - RandomFlip, + - SampleConditionalGMM, + - SampleResolution, + - GaussianBlur, + - DynamicGaussianBlur, + - MimicAcquisition, + - BiasFieldCorruption, + - IntensityAugmentation, + - DiceLoss, + - WeightedL2Loss, + - ResetValuesToZero, + - ConvertLabels, + - PadAroundCentre, + - MaskEdges + - ImageGradients + - RandomDilationErosion + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import keras +import numpy as np +import tensorflow as tf +import keras.backend as K +from keras.layers import Layer + +# project imports +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import edit_tensors as l2i_et + +# third-party imports +from nobrainer.ext.neuron import utils as nrn_utils +import nobrainer.ext.neuron.layers as nrn_layers + + +class RandomSpatialDeformation(Layer): + """This layer spatially deforms one or several tensors with a combination of affine and elastic transformations. + The input tensors are expected to have the same shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + The non-linear deformation is obtained by: + 1) a small-size SVF is sampled from a centred normal distribution of random standard deviation. + 2) it is resized with trilinear interpolation to half the shape of the input tensor + 3) it is integrated to obtain a diffeomorphic transformation + 4) finally, it is resized (again with trilinear interpolation) to full image size + :param scaling_bounds: (optional) range of the random scaling to apply. The scaling factor for each dimension is + sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.15 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param enable_90_rotations: (optional) whether to rotate the input by a random angle chosen in {0, 90, 180, 270}. + This is done regardless of the value of rotation_bounds. If true, a different value is sampled for each dimension. + :param nonlin_std: (optional) maximum value of the standard deviation of the normal distribution from which we + sample the small-size SVF. Set to 0 if you wish to completely turn the elastic deformation off. + :param nonlin_scale: (optional) if nonlin_std is not False, factor between the shapes of the input tensor + and the shape of the input non-linear tensor. + :param inter_method: (optional) interpolation method when deforming the input tensor. Can be 'linear', or 'nearest' + :param prob_deform: (optional) probability to apply spatial deformation + """ + + def __init__(self, + scaling_bounds=0.15, + rotation_bounds=10, + shearing_bounds=0.02, + translation_bounds=False, + enable_90_rotations=False, + nonlin_std=4., + nonlin_scale=.0625, + inter_method='linear', + prob_deform=1, + **kwargs): + + # shape attributes + self.n_inputs = 1 + self.inshape = None + self.n_dims = None + self.small_shape = None + + # deformation attributes + self.scaling_bounds = scaling_bounds + self.rotation_bounds = rotation_bounds + self.shearing_bounds = shearing_bounds + self.translation_bounds = translation_bounds + self.enable_90_rotations = enable_90_rotations + self.nonlin_std = nonlin_std + self.nonlin_scale = nonlin_scale + + # boolean attributes + self.apply_affine_trans = (self.scaling_bounds is not False) | (self.rotation_bounds is not False) | \ + (self.shearing_bounds is not False) | (self.translation_bounds is not False) | \ + self.enable_90_rotations + self.apply_elastic_trans = self.nonlin_std > 0 + self.prob_deform = prob_deform + + # interpolation methods + self.inter_method = inter_method + + super(RandomSpatialDeformation, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["scaling_bounds"] = self.scaling_bounds + config["rotation_bounds"] = self.rotation_bounds + config["shearing_bounds"] = self.shearing_bounds + config["translation_bounds"] = self.translation_bounds + config["enable_90_rotations"] = self.enable_90_rotations + config["nonlin_std"] = self.nonlin_std + config["nonlin_scale"] = self.nonlin_scale + config["inter_method"] = self.inter_method + config["prob_deform"] = self.prob_deform + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + inputshape = [input_shape] + else: + self.n_inputs = len(input_shape) + inputshape = input_shape + self.inshape = inputshape[0][1:] + self.n_dims = len(self.inshape) - 1 + + if self.apply_elastic_trans: + self.small_shape = utils.get_resample_shape(self.inshape[:self.n_dims], + self.nonlin_scale, self.n_dims) + else: + self.small_shape = None + + self.inter_method = utils.reformat_to_list(self.inter_method, length=self.n_inputs, dtype='str') + + self.built = True + super(RandomSpatialDeformation, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # reformat inputs and get its shape + if self.n_inputs < 2: + inputs = [inputs] + types = [v.dtype for v in inputs] + inputs = [tf.cast(v, dtype='float32') for v in inputs] + batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] + + # initialise list of transforms to operate + list_trans = list() + + # add affine deformation to inputs list + if self.apply_affine_trans: + affine_trans = utils.sample_affine_transform(batchsize, + self.n_dims, + self.rotation_bounds, + self.scaling_bounds, + self.shearing_bounds, + self.translation_bounds, + self.enable_90_rotations) + list_trans.append(affine_trans) + + # prepare non-linear deformation field and add it to inputs list + if self.apply_elastic_trans: + + # sample small field from normal distribution of specified std dev + trans_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_shape, dtype='int32')], axis=0) + trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std) + elastic_trans = tf.random.normal(trans_shape, stddev=trans_std) + + # reshape this field to half size (for smoother SVF), integrate it, and reshape to full image size + resize_shape = [max(int(self.inshape[i] / 2), self.small_shape[i]) for i in range(self.n_dims)] + elastic_trans = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans) + elastic_trans = nrn_layers.VecInt()(elastic_trans) + elastic_trans = nrn_layers.Resize(size=self.inshape[:self.n_dims], interp_method='linear')(elastic_trans) + list_trans.append(elastic_trans) + + # apply deformations and return tensors with correct dtype + if self.apply_affine_trans | self.apply_elastic_trans: + if self.prob_deform == 1: + inputs = [nrn_layers.SpatialTransformer(m)([v] + list_trans) for (m, v) in + zip(self.inter_method, inputs)] + else: + rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_deform)) + inputs = [K.switch(rand_trans, nrn_layers.SpatialTransformer(m)([v] + list_trans), v) + for (m, v) in zip(self.inter_method, inputs)] + if self.n_inputs < 2: + return tf.cast(inputs[0], types[0]) + else: + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + + +class RandomCrop(Layer): + """Randomly crop all input tensors to a given shape. This cropping is applied to all channels. + The input tensors are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param crop_shape: list with cropping shape in each dimension (excluding batch and channel dimension) + + example: + if input is a tensor of shape [batchsize, 160, 160, 160, 3], + output = RandomCrop(crop_shape=[96, 128, 96])(input) + will yield an output of shape [batchsize, 96, 128, 96, 3] that is obtained by cropping with randomly selected + cropping indices. + """ + + def __init__(self, crop_shape, **kwargs): + + self.several_inputs = True + self.crop_max_val = None + self.crop_shape = crop_shape + self.n_dims = len(crop_shape) + self.list_n_channels = None + super(RandomCrop, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["crop_shape"] = self.crop_shape + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + self.several_inputs = False + inputshape = [input_shape] + else: + inputshape = input_shape + self.crop_max_val = np.array(np.array(inputshape[0][1:self.n_dims + 1])) - np.array(self.crop_shape) + self.list_n_channels = [i[-1] for i in inputshape] + self.built = True + super(RandomCrop, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # if one input only is provided, performs the cropping directly + if not self.several_inputs: + return tf.map_fn(self._single_slice, inputs, dtype=inputs.dtype) + + # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location + else: + types = [v.dtype for v in inputs] + inputs = tf.concat([tf.cast(v, 'float32') for v in inputs], axis=-1) + inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32) + inputs = tf.split(inputs, self.list_n_channels, axis=-1) + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + + def _single_slice(self, vol): + crop_idx = tf.cast(tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), 'float32'), dtype='int32') + crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype='int32')], axis=0) + crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype='int32') + return tf.slice(vol, begin=crop_idx, size=crop_size) + + def compute_output_shape(self, input_shape): + output_shape = [tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels] + return output_shape if self.several_inputs else output_shape[0] + + +class RandomFlip(Layer): + """This layer randomly flips the input tensor along the specified axes with a specified probability. + It can also take multiple tensors as inputs (if they have the same shape). The same flips will be applied to all + input tensors. These are expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + If specified, this layer can also swap corresponding values. This is especially useful when flipping label maps + with different labels for right/left structures, such that the flipped label maps keep a consistent labelling. + :param axis: integer, or list of integers specifying the dimensions along which to flip. + If a list, the input tensors can be flipped simultaneously in several directions. The values in flip_axis exclude + the batch dimension (e.g. 0 will flip the tensor along the first axis after the batch dimension). + Default is None, where the tensors can be flipped along all axes (except batch and channel axes). + :param swap_labels: boolean to specify whether to swap the values of each input. Values are only swapped if an odd + number of flips is applied. + Can also be a list if several tensors are given as input. + All the inputs for which the values need to be swapped must be int32 or int64. + :param label_list: if swap_labels is True, list of all labels contained in labels. Must be ordered as follows, first + the neutral labels (i.e. non-sided), then left labels and right labels. + :param n_neutral_labels: if swap_labels is True, number of non-sided labels + :param prob: probability to flip along each specified axis + + example 1: + if input is a tensor of shape (batchsize, 10, 100, 200, 3) + output = RandomFlip()(input) will randomly flip input along one of the 1st, 2nd, or 3rd axis (i.e. those with shape + 10, 100, 200). + + example 2: + if input is a tensor of shape (batchsize, 10, 100, 200, 3) + output = RandomFlip(flip_axis=1)(input) will randomly flip input along the 3rd axis (with shape 100), i.e. the axis + with index 1 if we don't count the batch axis. + + example 3: + input = tf.convert_to_tensor(np.array([[1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 2, 2, 0], + [1, 0, 0, 0, 0, 0, 0]])) + label_list = np.array([0, 1, 2]) + n_neutral_labels = 1 + output = RandomFlip(flip_axis=1, swap_labels=True, label_list=label_list, n_neutral_labels=n_neutral_labels)(input) + where output will either be equal to input (bear in mind the flipping occurs with a 0.5 probability), or: + output = [[0, 0, 0, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 1, 1, 0, 0, 0, 2], + [0, 0, 0, 0, 0, 0, 2]] + Note that the input must have a dtype int32 or int64 for its values to be swapped, otherwise an error will be raised + + example 4: + if labels is the same as in the input of example 3, and image is a float32 image, then we can swap consistently both + the labels and the image with: + labels, image = RandomFlip(flip_axis=1, swap_labels=[True, False], label_list=label_list, + n_neutral_labels=n_neutral_labels)([labels, image]]) + Note that the labels must have a dtype int32 or int64 to be swapped, otherwise an error will be raised. + This doesn't concern the image input, as its values are not swapped. + """ + + def __init__(self, axis=None, swap_labels=False, label_list=None, n_neutral_labels=None, prob=0.5, **kwargs): + + # shape attributes + self.several_inputs = True + self.n_dims = None + self.list_n_channels = None + + # axis along which to flip + self.axis = utils.reformat_to_list(axis) + self.flip_axes = None + + # whether to swap labels, and corresponding label list + self.swap_labels = utils.reformat_to_list(swap_labels) + self.label_list = label_list + self.n_neutral_labels = n_neutral_labels + self.swap_lut = None + + self.prob = prob + + super(RandomFlip, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["axis"] = self.axis + config["swap_labels"] = self.swap_labels + config["label_list"] = self.label_list + config["n_neutral_labels"] = self.n_neutral_labels + config["prob"] = self.prob + return config + + def build(self, input_shape): + + if not isinstance(input_shape, list): + self.several_inputs = False + inputshape = [input_shape] + else: + inputshape = input_shape + self.n_dims = len(inputshape[0][1:-1]) + self.list_n_channels = [i[-1] for i in inputshape] + self.swap_labels = utils.reformat_to_list(self.swap_labels, length=len(inputshape)) + self.flip_axes = np.arange(self.n_dims).tolist() if self.axis is None else self.axis + + # create label list with swapped labels + if any(self.swap_labels): + assert (self.label_list is not None) & (self.n_neutral_labels is not None), \ + 'please provide a label_list, and n_neutral_labels when swapping the values of at least one input' + n_labels = len(self.label_list) + if self.n_neutral_labels == n_labels: + self.swap_labels = [False] * len(self.swap_labels) + else: + rl_split = np.split(self.label_list, [self.n_neutral_labels, + self.n_neutral_labels + int((n_labels-self.n_neutral_labels)/2)]) + label_list_swap = np.concatenate((rl_split[0], rl_split[2], rl_split[1])) + swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap) + self.swap_lut = tf.convert_to_tensor(swap_lut, dtype='int32') + + self.built = True + super(RandomFlip, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # convert inputs to list, and get each input type + inputs = [inputs] if not self.several_inputs else inputs + types = [v.dtype for v in inputs] + + # store whether to flip along each specified dimension + batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] + size = tf.concat([batchsize, len(self.flip_axes) * tf.ones(1, dtype='int32')], axis=0) + rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob) + + # swap right/left labels if we apply an odd number of flips + odd = tf.math.floormod(tf.reduce_sum(tf.cast(rand_flip, 'int32'), -1, keepdims=True), 2) != 0 + swapped_inputs = list() + for i in range(len(inputs)): + if self.swap_labels[i]: + swapped_inputs.append(tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i])) + else: + swapped_inputs.append(inputs[i]) + + # flip inputs and convert them back to their original type + inputs = tf.concat([tf.cast(v, 'float32') for v in swapped_inputs], axis=-1) + inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32) + inputs = tf.split(inputs, self.list_n_channels, axis=-1) + + if self.several_inputs: + return [tf.cast(v, t) for (t, v) in zip(types, inputs)] + else: + return tf.cast(inputs[0], types[0]) + + def _single_swap(self, inputs): + return K.switch(inputs[1], tf.gather(self.swap_lut, inputs[0]), inputs[0]) + + @staticmethod + def _single_flip(inputs): + flip_axis = tf.where(inputs[1]) + return K.switch(tf.equal(tf.size(flip_axis), 0), inputs[0], tf.reverse(inputs[0], axis=flip_axis[..., 0])) + + +class SampleConditionalGMM(Layer): + """This layer generates an image by sampling a Gaussian Mixture Model conditioned on a label map given as input. + The parameters of the GMM are given as two additional inputs to the layer (means and standard deviations): + image = SampleConditionalGMM(generation_labels)([label_map, means, stds]) + + :param generation_labels: list of all possible label values contained in the input label maps. + Must be a list or a 1D numpy array of size N, where N is the total number of possible label values. + + Layer inputs: + label_map: input label map of shape [batchsize, shape_dim1, ..., shape_dimn, n_channel]. + All the values of label_map must be contained in generation_labels, but the input label_map doesn't necessarily have + to contain all the values in generation_labels. + means: tensor containing the mean values of all Gaussian distributions of the GMM. + It must be of shape [batchsize, N, n_channel], and in the same order as generation label, + i.e. the ith value of generation_labels will be associated to the ith value of means. + stds: same as means but for the standard deviations of the GMM. + """ + + def __init__(self, generation_labels, **kwargs): + self.generation_labels = generation_labels + self.n_labels = None + self.n_channels = None + self.max_label = None + self.indices = None + self.shape = None + super(SampleConditionalGMM, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["generation_labels"] = self.generation_labels + return config + + def build(self, input_shape): + + # check n_labels and n_channels + assert len(input_shape) == 3, 'should have three inputs: labels, means, std devs (in that order).' + self.n_channels = input_shape[1][-1] + self.n_labels = len(self.generation_labels) + assert self.n_labels == input_shape[1][1], 'means should have the same number of values as generation_labels' + assert self.n_labels == input_shape[2][1], 'stds should have the same number of values as generation_labels' + + # scatter parameters (to build mean/std lut) + self.max_label = np.max(self.generation_labels) + 1 + indices = np.concatenate([self.generation_labels + self.max_label * i for i in range(self.n_channels)], axis=-1) + self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype='int32') + self.indices = tf.convert_to_tensor(utils.add_axis(indices, axis=[0, -1]), dtype='int32') + + self.built = True + super(SampleConditionalGMM, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # reformat labels and scatter indices + batch = tf.split(tf.shape(inputs[0]), [1, -1])[0] + tmp_indices = tf.tile(self.indices, tf.concat([batch, tf.convert_to_tensor([1, 1], dtype='int32')], axis=0)) + labels = tf.concat([tf.cast(inputs[0], dtype='int32') + self.max_label * i for i in range(self.n_channels)], -1) + + # build mean map + means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1) + tile_shape = tf.concat([batch, tf.convert_to_tensor([1, ], dtype='int32')], axis=0) + means = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape) + means_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32) + + # same for stds + stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1) + stds = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape) + stds_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32) + + return stds_map * tf.random.normal(tf.shape(labels)) + means_map + + def compute_output_shape(self, input_shape): + return input_shape[0] if (self.n_channels == 1) else tuple(list(input_shape[0][:-1]) + [self.n_channels]) + + +class SampleResolution(Layer): + """Build a random resolution tensor by sampling a uniform distribution of provided range. + + You can use this layer in the following ways: + resolution = SampleConditionalGMM(min_resolution)() in this case resolution will be a tensor of shape (n_dims,), + where n_dims is the length of the min_resolution parameter (provided as a list, see below). + resolution = SampleConditionalGMM(min_resolution)(input), where input is a tensor for which the first dimension + represents the batch_size. In this case resolution will be a tensor of shape (batchsize, n_dims,). + + :param min_resolution: list of length n_dims specifying the inferior bounds of the uniform distributions to + sample from for each value. + :param max_res_iso: If not None, all the values of resolution will be equal to the same value, which is randomly + sampled at each minibatch in U(min_resolution, max_res_iso). + :param max_res_aniso: If not None, we first randomly select a direction i in the range [0, n_dims-1], and we sample + a value in the corresponding uniform distribution U(min_resolution[i], max_res_aniso[i]). + The other values of resolution will be set to min_resolution. + :param prob_iso: if both max_res_iso and max_res_aniso are specified, this allows to specify the probability of + sampling an isotropic resolution (therefore using max_res_iso) with respect to anisotropic resolution + (which would use max_res_aniso). + :param prob_min: if not zero, this allows to return with the specified probability an output resolution equal + to min_resolution. + :param return_thickness: if set to True, this layer will also return a thickness value of the same shape as + resolution, which will be sampled independently for each axis from the uniform distribution + U(min_resolution, resolution). + + """ + + def __init__(self, + min_resolution, + max_res_iso=None, + max_res_aniso=None, + prob_iso=0.1, + prob_min=0.05, + return_thickness=True, + **kwargs): + + self.min_res = min_resolution + self.max_res_iso_input = max_res_iso + self.max_res_iso = None + self.max_res_aniso_input = max_res_aniso + self.max_res_aniso = None + self.prob_iso = prob_iso + self.prob_min = prob_min + self.return_thickness = return_thickness + self.n_dims = len(self.min_res) + self.add_batchsize = False + self.min_res_tens = None + super(SampleResolution, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["min_resolution"] = self.min_res + config["max_res_iso"] = self.max_res_iso + config["max_res_aniso"] = self.max_res_aniso + config["prob_iso"] = self.prob_iso + config["prob_min"] = self.prob_min + config["return_thickness"] = self.return_thickness + return config + + def build(self, input_shape): + + # check maximum resolutions + assert ((self.max_res_iso_input is not None) | (self.max_res_aniso_input is not None)), \ + 'at least one of maximum isotropic or anisotropic resolutions must be provided, received none' + + # reformat resolutions as numpy arrays + self.min_res = np.array(self.min_res) + if self.max_res_iso_input is not None: + self.max_res_iso = np.array(self.max_res_iso_input) + assert len(self.min_res) == len(self.max_res_iso), \ + 'min and isotropic max resolution must have the same length, ' \ + 'had {0} and {1}'.format(self.min_res, self.max_res_iso) + if np.array_equal(self.min_res, self.max_res_iso): + self.max_res_iso = None + if self.max_res_aniso_input is not None: + self.max_res_aniso = np.array(self.max_res_aniso_input) + assert len(self.min_res) == len(self.max_res_aniso), \ + 'min and anisotropic max resolution must have the same length, ' \ + 'had {} and {}'.format(self.min_res, self.max_res_aniso) + if np.array_equal(self.min_res, self.max_res_aniso): + self.max_res_aniso = None + + # check prob iso + if (self.max_res_iso is not None) & (self.max_res_aniso is not None) & (self.prob_iso == 0): + raise Exception('prob iso is 0 while sampling either isotropic and anisotropic resolutions is enabled') + + if input_shape: + self.add_batchsize = True + + self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32') + + self.built = True + super(SampleResolution, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if not self.add_batchsize: + shape = [self.n_dims] + dim = tf.random.uniform(shape=(1, 1), minval=0, maxval=self.n_dims, dtype='int32') + mask = tf.tensor_scatter_nd_update(tf.zeros([self.n_dims], dtype='bool'), dim, + tf.convert_to_tensor([True], dtype='bool')) + else: + batch = tf.split(tf.shape(inputs), [1, -1])[0] + tile_shape = tf.concat([batch, tf.convert_to_tensor([1], dtype='int32')], axis=0) + self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape) + + shape = tf.concat([batch, tf.convert_to_tensor([self.n_dims], dtype='int32')], axis=0) + indices = tf.stack([tf.range(0, batch[0]), tf.random.uniform(batch, 0, self.n_dims, dtype='int32')], 1) + mask = tf.tensor_scatter_nd_update(tf.zeros(shape, dtype='bool'), indices, tf.ones(batch, dtype='bool')) + + # return min resolution as tensor if min=max + if (self.max_res_iso is None) & (self.max_res_aniso is None): + new_resolution = self.min_res_tens + + # sample isotropic resolution only + elif (self.max_res_iso is not None) & (self.max_res_aniso is None): + new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution_iso) + + # sample anisotropic resolution only + elif (self.max_res_iso is None) & (self.max_res_aniso is not None): + new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + tf.where(mask, new_resolution_aniso, self.min_res_tens)) + + # sample either anisotropic or isotropic resolution + else: + new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) + new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), + new_resolution_iso, + tf.where(mask, new_resolution_aniso, self.min_res_tens)) + new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution) + + if self.return_thickness: + return [new_resolution, tf.random.uniform(tf.shape(self.min_res_tens), self.min_res_tens, new_resolution)] + else: + return new_resolution + + def compute_output_shape(self, input_shape): + if self.return_thickness: + return [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2 + else: + return (None, self.n_dims) if self.add_batchsize else self.n_dims + + +class GaussianBlur(Layer): + """Applies gaussian blur to an input image. + The input image is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param sigma: standard deviation of the blurring kernels to apply. Can be a number, a list of length n_dims, or + a numpy array. + :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where + sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds + [1/random_blur_range, random_blur_range]. + :param use_mask: (optional) whether a mask of the input will be provided as an additional layer input. This is used + to mask the blurred image, and to correct for edge blurring effects. + + example 1: + output = GaussianBlur(sigma=0.5)(input) will isotropically blur the input with a gaussian kernel of std 0.5. + + example 2: + if input is a tensor of shape [batchsize, 10, 100, 200, 2] + output = GaussianBlur(sigma=[0.5, 1, 10])(input) will blur the input a different gaussian kernel in each dimension. + + example 3: + output = GaussianBlur(sigma=0.5, random_blur_range=1.15)(input) + will blur the input a different gaussian kernel in each dimension, as each dimension will be associated with + a kernel, whose standard deviation will be uniformly sampled from [0.5/1.15; 0.5*1.15]. + + example 4: + output = GaussianBlur(sigma=0.5, use_mask=True)([input, mask]) + will 1) blur the input a different gaussian kernel in each dimension, 2) mask the blurred image with the provided + mask, and 3) correct for edge blurring effects. If the provided mask is not of boolean type, it will be thresholded + above positive values. + """ + + def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs): + self.sigma = utils.reformat_to_list(sigma) + assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0' + self.use_mask = use_mask + + self.n_dims = None + self.n_channels = None + self.blur_range = random_blur_range + self.stride = None + self.separable = None + self.kernels = None + self.convnd = None + super(GaussianBlur, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["sigma"] = self.sigma + config["random_blur_range"] = self.blur_range + config["use_mask"] = self.use_mask + return config + + def build(self, input_shape): + + # get shapes + if self.use_mask: + assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True' + self.n_dims = len(input_shape[0]) - 2 + self.n_channels = input_shape[0][-1] + else: + self.n_dims = len(input_shape) - 2 + self.n_channels = input_shape[-1] + + # prepare blurring kernel + self.stride = [1] * (self.n_dims + 2) + self.sigma = utils.reformat_to_list(self.sigma, length=self.n_dims) + self.separable = np.linalg.norm(np.array(self.sigma)) > 5 + if self.blur_range is None: # fixed kernels + self.kernels = l2i_et.gaussian_kernel(self.sigma, separable=self.separable) + else: + self.kernels = None + + # prepare convolution + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + + self.built = True + super(GaussianBlur, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if self.use_mask: + image = inputs[0] + mask = tf.cast(inputs[1], 'bool') + else: + image = inputs + mask = None + + # redefine the kernels at each new step when blur_range is activated + if self.blur_range is not None: + self.kernels = l2i_et.gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable) + + if self.separable: + for k in self.kernels: + if k is not None: + image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + if self.use_mask: + maskb = tf.cast(mask, 'float32') + maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), k, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + image = image / (maskb + K.epsilon()) + image = tf.where(mask, image, tf.zeros_like(image)) + else: + if any(self.sigma): + image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), self.kernels, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + if self.use_mask: + maskb = tf.cast(mask, 'float32') + maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + image = image / (maskb + K.epsilon()) + image = tf.where(mask, image, tf.zeros_like(image)) + + return image + + +class DynamicGaussianBlur(Layer): + """Applies gaussian blur to an input image, where the standard deviation of the blurring kernel is provided as a + layer input, which enables to perform dynamic blurring (i.e. the blurring kernel can vary at each minibatch). + :param max_sigma: maximum value of the standard deviation that will be provided as input. This is used to compute + the size of the blurring kernels. It must be provided as a list of length n_dims. + :param random_blur_range: (optional) if not None, this introduces a randomness in the blurring kernels, where + sigma is now multiplied by a coefficient dynamically sampled from a uniform distribution with bounds + [1/random_blur_range, random_blur_range]. + + example: + blurred_image = DynamicGaussianBlur(max_sigma=[5.]*3, random_blurring_range=1.15)([image, sigma]) + will return a blurred version of image, where the standard deviation of each dimension (given as a tensor, and with + values lower than 5 for each axis) is multiplied by a random coefficient uniformly sampled from [1/1.15; 1.15]. + """ + + def __init__(self, max_sigma, random_blur_range=None, **kwargs): + self.max_sigma = max_sigma + self.n_dims = None + self.n_channels = None + self.convnd = None + self.blur_range = random_blur_range + self.separable = None + super(DynamicGaussianBlur, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["max_sigma"] = self.max_sigma + config["random_blur_range"] = self.blur_range + return config + + def build(self, input_shape): + assert len(input_shape) == 2, 'sigma should be provided as an input tensor for dynamic blurring' + self.n_dims = len(input_shape[0]) - 2 + self.n_channels = input_shape[0][-1] + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims) + self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5 + self.built = True + super(DynamicGaussianBlur, self).build(input_shape) + + def call(self, inputs, **kwargs): + image = inputs[0] + sigma = inputs[-1] + kernels = l2i_et.gaussian_kernel(sigma, self.max_sigma, self.blur_range, self.separable) + if self.separable: + for kernel in kernels: + image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32) + else: + image = tf.map_fn(self._single_blur, [image, kernels], dtype=tf.float32) + return image + + def _single_blur(self, inputs): + if self.n_channels > 1: + split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1) + blurred_channel = list() + for channel in split_channels: + blurred = self.convnd(tf.expand_dims(channel, 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + blurred_channel.append(tf.squeeze(blurred, axis=0)) + output = tf.concat(blurred_channel, -1) + else: + output = self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + output = tf.squeeze(output, axis=0) + return output + + +class MimicAcquisition(Layer): + """ + Layer that takes an image as input, and simulates data that has been acquired at low resolution. + The output is obtained by resampling the input twice: + - first at a resolution given as an input (i.e. the "acquisition" resolution), + - then at the output resolution (specified output shape). + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param volume_res: resolution of the provided inputs. Must be a 1-D numpy array with n_dims elements. + :param min_subsample_res: lower bound of the acquisition resolutions to mimic (i.e. the input resolution must have + values higher than min-subsample_res). + :param resample_shape: shape of the output tensor + :param build_dist_map: whether to return distance maps as outputs. These indicate the distance between each voxel + and the nearest non-interpolated voxel (during the second resampling). + :param prob_noise: probability to apply noise injection + + example 1: + im_res = [1., 1., 1.] + low_res = [1., 1., 3.] + res = tf.convert_to_tensor([1., 1., 4.5]) + image is a tensor of shape (None, 256, 256, 256, 3) + resample_shape = [256, 256, 256] + output = MimicAcquisition(im_res, low_res, resample_shape)([image, res]) + output will be a tensor of shape (None, 256, 256, 256, 3), obtained by downsampling image to [1., 1., 4.5]. + and re-upsampling it at initial resolution (because resample_shape is equal to the input shape). In this example all + examples of the batch will be downsampled to the same resolution (because res has no batch dimension). + Note that the provided res must have higher values than min_low_res. + + example 2: + im_res = [1., 1., 1.] + min_low_res = [1., 1., 1.] + res is a tensor of shape (None, 3), obtained for example by using the SampleResolution layer (see above). + image is a tensor of shape (None, 256, 256, 256, 1) + resample_shape = [128, 128, 128] + output = MimicAcquisition(im_res, low_res, resample_shape)([image, res]) + output will be a tensor of shape (None, 128, 128, 128, 1), obtained by downsampling each examples of the batch to + the matching resolution in res, and resampling them all to half the initial resolution. + Note that the provided res must have higher values than min_low_res. + """ + + def __init__(self, volume_res, min_subsample_res, resample_shape, build_dist_map=False, + noise_std=0, prob_noise=0.95, **kwargs): + + # resolutions and dimensions + self.volume_res = volume_res + self.min_subsample_res = min_subsample_res + self.n_dims = len(self.volume_res) + self.n_channels = None + self.add_batchsize = None + + # noise + self.noise_std = noise_std + self.prob_noise = prob_noise + + # input and output shapes + self.inshape = None + self.resample_shape = resample_shape + + # meshgrids for resampling + self.down_grid = None + self.up_grid = None + + # whether to return a map indicating the distance from the interpolated voxels, to acquired ones. + self.build_dist_map = build_dist_map + + super(MimicAcquisition, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["volume_res"] = self.volume_res + config["min_subsample_res"] = self.min_subsample_res + config["resample_shape"] = self.resample_shape + config["build_dist_map"] = self.build_dist_map + config["noise_std"] = self.noise_std + config["prob_noise"] = self.prob_noise + return config + + def build(self, input_shape): + + # set up input shape and acquisition shape + self.inshape = input_shape[0][1:] + self.n_channels = input_shape[0][-1] + self.add_batchsize = False if (input_shape[1][0] is None) else True + down_tensor_shape = np.int32(np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res) + + # build interpolation meshgrids + self.down_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0) + self.up_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0) + + self.built = True + super(MimicAcquisition, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # sort inputs + assert len(inputs) == 2, 'inputs must have two items, the tensor to resample, and the downsampling resolution' + vol = inputs[0] + subsample_res = tf.cast(inputs[1], dtype='float32') + vol = K.reshape(vol, [-1, *self.inshape]) # necessary for multi_gpu models + batchsize = tf.split(tf.shape(vol), [1, -1])[0] + tile_shape = tf.concat([batchsize, tf.ones([1], dtype='int32')], 0) + + # get downsampling and upsampling factors + if self.add_batchsize: + subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape) + down_shape = tf.cast(tf.convert_to_tensor(np.array(self.inshape[:-1]) * self.volume_res, dtype='float32') / + subsample_res, dtype='int32') + down_zoom_factor = tf.cast(down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype='float32') + up_zoom_factor = tf.cast(tf.convert_to_tensor(self.resample_shape, dtype='int32') / down_shape, dtype='float32') + + # downsample + down_loc = tf.tile(self.down_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], 0)) + down_loc = tf.cast(down_loc, 'float32') / l2i_et.expand_dims(down_zoom_factor, axis=[1] * self.n_dims) + inshape_tens = tf.tile(tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape) + inshape_tens = l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims) + down_loc = K.clip(down_loc, 0., tf.cast(inshape_tens, 'float32')) + vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32) + + # add noise with predefined probability + if self.noise_std > 0: + sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32'), + self.n_channels * tf.ones([1], dtype='int32')], 0) + noise = tf.random.normal(tf.shape(vol), stddev=tf.random.uniform(sample_shape, maxval=self.noise_std)) + if self.prob_noise == 1: + vol += noise + else: + vol = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), vol + noise, vol) + + # upsample + up_loc = tf.tile(self.up_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], axis=0)) + up_loc = tf.cast(up_loc, 'float32') / l2i_et.expand_dims(up_zoom_factor, axis=[1] * self.n_dims) + vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32) + + # return upsampled volume + if not self.build_dist_map: + return vol + + # return upsampled volumes with distance maps + else: + + # get grid points + floor = tf.math.floor(up_loc) + ceil = tf.math.ceil(up_loc) + + # get distances of every voxel to higher and lower grid points for every dimension + f_dist = up_loc - floor + c_dist = ceil - up_loc + + # keep minimum 1d distances, and compute 3d distance to nearest grid point + dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims(subsample_res, axis=[1] * self.n_dims) + dist = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True)) + + return [vol, dist] + + @staticmethod + def _single_down_interpn(inputs): + return nrn_utils.interpn(inputs[0], inputs[1], interp_method='nearest') + + @staticmethod + def _single_up_interpn(inputs): + return nrn_utils.interpn(inputs[0], inputs[1], interp_method='linear') + + def compute_output_shape(self, input_shape): + output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]]) + return [output_shape] * 2 if self.build_dist_map else output_shape + + +class BiasFieldCorruption(Layer): + """This layer applies a smooth random bias field to the input by applying the following steps: + 1) we first sample a value for the standard deviation of a centred normal distribution + 2) a small-size SVF is sampled from this normal distribution + 3) the small SVF is then resized with trilinear interpolation to image size + 4) it is rescaled to positive values by taking the voxel-wise exponential + 5) it is multiplied to the input tensor. + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param bias_field_std: maximum value of the standard deviation sampled in 1 (it will be sampled from the range + [0, bias_field_std]) + :param bias_scale: ratio between the shape of the input tensor and the shape of the sampled SVF. + :param same_bias_for_all_channels: whether to apply the same bias field to all the channels of the input tensor. + :param prob: probability to apply this bias field corruption. + """ + + def __init__(self, bias_field_std=.5, bias_scale=.025, same_bias_for_all_channels=False, prob=0.95, **kwargs): + + # input shape + self.several_inputs = False + self.inshape = None + self.n_dims = None + self.n_channels = None + + # sampling shape + self.std_shape = None + self.small_bias_shape = None + + # bias field parameters + self.bias_field_std = bias_field_std + self.bias_scale = bias_scale + self.same_bias_for_all_channels = same_bias_for_all_channels + self.prob = prob + + super(BiasFieldCorruption, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["bias_field_std"] = self.bias_field_std + config["bias_scale"] = self.bias_scale + config["same_bias_for_all_channels"] = self.same_bias_for_all_channels + config["prob"] = self.prob + return config + + def build(self, input_shape): + + # input shape + if isinstance(input_shape, list): + self.several_inputs = True + self.inshape = input_shape + else: + self.inshape = [input_shape] + self.n_dims = len(self.inshape[0]) - 2 + self.n_channels = self.inshape[0][-1] + + # sampling shapes + self.std_shape = [1] * (self.n_dims + 1) + self.small_bias_shape = utils.get_resample_shape(self.inshape[0][1:self.n_dims + 1], self.bias_scale, 1) + if not self.same_bias_for_all_channels: + self.std_shape[-1] = self.n_channels + self.small_bias_shape[-1] = self.n_channels + + self.built = True + super(BiasFieldCorruption, self).build(input_shape) + + def call(self, inputs, **kwargs): + + if not self.several_inputs: + inputs = [inputs] + + if self.bias_field_std > 0: + + # sampling shapes + batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0] + std_shape = tf.concat([batchsize, tf.convert_to_tensor(self.std_shape, dtype='int32')], 0) + bias_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype='int32')], axis=0) + + # sample small bias field + bias_field = tf.random.normal(bias_shape, stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std)) + + # resize bias field and take exponential + bias_field = nrn_layers.Resize(size=self.inshape[0][1:self.n_dims + 1], interp_method='linear')(bias_field) + bias_field = tf.math.exp(bias_field) + + # apply bias field with predefined probability + if self.prob == 1: + return [tf.math.multiply(bias_field, v) for v in inputs] + else: + rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob)) + if self.several_inputs: + return [K.switch(rand_trans, tf.math.multiply(bias_field, v), v) for v in inputs] + else: + return K.switch(rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0]) + + else: + return inputs + + +class IntensityAugmentation(Layer): + """This layer enables to augment the intensities of the input tensor, as well as to apply min_max normalisation. + The following steps are applied (all are optional): + 1) white noise corruption, with a randomly sampled std dev. + 2) clip the input between two values + 3) min-max normalisation + 4) gamma augmentation (i.e. voxel-wise exponentiation by a randomly sampled power) + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + + :param noise_std: maximum value of the standard deviation of the Gaussian white noise used in 1 (it will be sampled + from the range [0, noise_std]). Set to 0 to skip this step. + :param clip: clip the input tensor between the given values. Can either be: a number (in which case we clip between + 0 and the given value), or a list or a numpy array with two elements. Default is 0, where no clipping occurs. + :param normalise: whether to apply min-max normalisation, to normalise between 0 and 1. Default is True. + :param norm_perc: percentiles (between 0 and 1) of the sorted intensity values for robust normalisation. Can be: + a number (in which case the robust minimum is the provided percentile of sorted values, and the maximum is the + 1 - norm_perc percentile), or a list/numpy array of 2 elements (percentiles for the minimum and maximum values). + The minimum and maximum values are computed separately for each channel if separate_channels is True. + Default is 0, where we simply take the minimum and maximum values. + :param gamma_std: standard deviation of the normal distribution from which we sample gamma (in log domain). + Default is 0, where no gamma augmentation occurs. + :param contrast_inversion: whether to perform contrast inversion (i.e. 1 - x). If True, this is performed randomly + for each element of the batch, as well as for each channel. + :param separate_channels: whether to augment all channels separately. Default is True. + :param prob_noise: probability to apply noise injection + :param prob_gamma: probability to apply gamma augmentation + """ + + def __init__(self, noise_std=0, clip=0, normalise=True, norm_perc=0, gamma_std=0, contrast_inversion=False, + separate_channels=True, prob_noise=0.95, prob_gamma=1, **kwargs): + + # shape attributes + self.n_dims = None + self.n_channels = None + self.flatten_shape = None + self.expand_minmax_dim = None + self.one = None + + # inputs + self.noise_std = noise_std + self.clip = clip + self.clip_values = None + self.normalise = normalise + self.norm_perc = norm_perc + self.perc = None + self.gamma_std = gamma_std + self.separate_channels = separate_channels + self.contrast_inversion = contrast_inversion + self.prob_noise = prob_noise + self.prob_gamma = prob_gamma + + super(IntensityAugmentation, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["noise_std"] = self.noise_std + config["clip"] = self.clip + config["normalise"] = self.normalise + config["norm_perc"] = self.norm_perc + config["gamma_std"] = self.gamma_std + config["separate_channels"] = self.separate_channels + config["prob_noise"] = self.prob_noise + config["prob_gamma"] = self.prob_gamma + return config + + def build(self, input_shape): + self.n_dims = len(input_shape) - 2 + self.n_channels = input_shape[-1] + self.flatten_shape = np.prod(np.array(input_shape[1:-1])) + self.flatten_shape = self.flatten_shape * self.n_channels if not self.separate_channels else self.flatten_shape + self.expand_minmax_dim = self.n_dims if self.separate_channels else self.n_dims + 1 + self.one = tf.ones([1], dtype='int32') + if self.clip: + self.clip_values = utils.reformat_to_list(self.clip) + self.clip_values = self.clip_values if len(self.clip_values) == 2 else [0, self.clip_values[0]] + else: + self.clip_values = None + if self.norm_perc: + self.perc = utils.reformat_to_list(self.norm_perc) + self.perc = self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + else: + self.perc = None + + self.built = True + super(IntensityAugmentation, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately) + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion: + sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32')], 0) + if self.separate_channels: + sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0) + else: + sample_shape = tf.concat([sample_shape, self.one], 0) + else: + sample_shape = None + + # add noise with predefined probability + if self.noise_std > 0: + noise_stddev = tf.random.uniform(sample_shape, maxval=self.noise_std) + if self.separate_channels: + noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev) + else: + noise = tf.random.normal(tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev) + noise = tf.tile(noise, tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels])) + if self.prob_noise == 1: + inputs = inputs + noise + else: + inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + inputs + noise, inputs) + + # clip images to given values + if self.clip_values is not None: + inputs = K.clip(inputs, self.clip_values[0], self.clip_values[1]) + + # normalise + if self.normalise: + # define robust min and max by sorting values and taking percentile + if self.perc is not None: + if self.separate_channels: + shape = tf.concat([batchsize, self.flatten_shape * self.one, self.n_channels * self.one], 0) + else: + shape = tf.concat([batchsize, self.flatten_shape * self.one], 0) + intensities = tf.sort(tf.reshape(inputs, shape), axis=1) + m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...] + M = intensities[:, min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), ...] + # simple min and max + else: + m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) + M = K.max(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) + # normalise + m = l2i_et.expand_dims(m, axis=[1] * self.expand_minmax_dim) + M = l2i_et.expand_dims(M, axis=[1] * self.expand_minmax_dim) + inputs = tf.clip_by_value(inputs, m, M) + inputs = (inputs - m) / (M - m + K.epsilon()) + + # apply voxel-wise exponentiation with predefined probability + if self.gamma_std > 0: + gamma = tf.random.normal(sample_shape, stddev=self.gamma_std) + if self.prob_gamma == 1: + inputs = tf.math.pow(inputs, tf.math.exp(gamma)) + else: + inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), + tf.math.pow(inputs, tf.math.exp(gamma)), inputs) + + # apply random contrast inversion + if self.contrast_inversion: + rand_invert = tf.less(tf.random.uniform(sample_shape, maxval=1), 0.5) + split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1) + split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1) + inverted_channel = list() + for (channel, invert) in zip(split_channels, split_rand_invert): + inverted_channel.append(tf.map_fn(self._single_invert, [channel, invert], dtype=channel.dtype)) + inputs = tf.concat(inverted_channel, -1) + + return inputs + + @staticmethod + def _single_invert(inputs): + return K.switch(tf.squeeze(inputs[1]), 1 - inputs[0], inputs[0]) + + +class DiceLoss(Layer): + """This layer computes the soft Dice loss between two tensors. + These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: dice_loss = DiceLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures + when computing the loss. Default is 0 where no boundary weighting is applied. + :param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels + within this distance to a region boundary. Default is 3. + :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be + redundant when we have several labels. This is only used if boundary_weight is not 0. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs): + + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.boundary_weights = boundary_weights + self.boundary_dist = boundary_dist + self.skip_background = skip_background + self.enable_checks = enable_checks + self.spatial_axes = None + self.avg_pooling_layer = None + super(DiceLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["boundary_weights"] = self.boundary_weights + config["boundary_dist"] = self.boundary_dist + config["skip_background"] = self.skip_background + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.skip_background = False if n_labels == 1 else self.skip_background + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) + + self.built = True + super(DiceLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] + pred = inputs[1] + if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps + gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + pred = K.clip(pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + + # compute dice loss for each label + top = 2 * gt * pred + bottom = tf.math.square(gt) + tf.math.square(pred) + + # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) + if self.boundary_weights: + avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) + boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + if self.skip_background: + boundaries_channels = tf.unstack(boundaries, axis=-1) + boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundary_weights_tensor = 1 + self.boundary_weights * boundaries + top *= boundary_weights_tensor + bottom *= boundary_weights_tensor + else: + boundary_weights_tensor = None + + # compute loss + top = tf.math.reduce_sum(top, self.spatial_axes) + bottom = tf.math.reduce_sum(bottom, self.spatial_axes) + dice = (top + tf.keras.backend.epsilon()) / (bottom + tf.keras.backend.epsilon()) + loss = 1 - dice + + # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). + if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes) + else: + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self. class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + loss = tf.reduce_sum(loss * self.class_weights_tens, -1) + + return tf.math.reduce_mean(loss) + + def compute_output_shape(self, input_shape): + return [[]] + + +class WeightedL2Loss(Layer): + """This layer computes a L2 loss weighted by a specified factor (target_value) between two tensors. + This is designed to be used on the layer before the softmax. + The tensors are expected to have the same shape [batchsize, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: wl2_loss = WeightedL2Loss()([gt, pred]) + + :param target_value: target value for the layer before softmax: target_value when gt = 1, -target_value when gt = 0. + """ + + def __init__(self, target_value=5, **kwargs): + self.target_value = target_value + self.n_labels = None + super(WeightedL2Loss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["target_value"] = self.target_value + return config + + def build(self, input_shape): + assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + self.n_labels = input_shape[0][-1] + self.built = True + super(WeightedL2Loss, self).build(input_shape) + + def call(self, inputs, **kwargs): + gt = inputs[0] + pred = inputs[1] + weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1) + return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (K.sum(weights) * self.n_labels) + + def compute_output_shape(self, input_shape): + return [[]] + + +class CrossEntropyLoss(Layer): + """This layer computes the cross-entropy loss between two tensors. + These tensors are expected to have the same shape (one-hot encoding) [batch, size_dim1, ..., size_dimN, n_labels]. + The first input tensor is the GT and the second is the prediction: ce_loss = CrossEntropyLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param boundary_weights: (optional) bonus weight that we apply to the voxels close to boundaries between structures + when computing the loss. Default is 0 where no boundary weighting is applied. + :param boundary_dist: (optional) if boundary_weight is not 0, the extra boundary weighting is applied to all voxels + within this distance to a region boundary. Default is 3. + :param skip_background: (optional) whether to skip boundary weighting for the background class, as this may be + redundant when we have several labels. This is only used if boundary_weight is not 0. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs): + + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.boundary_weights = boundary_weights + self.boundary_dist = boundary_dist + self.skip_background = skip_background + self.enable_checks = enable_checks + self.spatial_axes = None + self.avg_pooling_layer = None + super(CrossEntropyLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["boundary_weights"] = self.boundary_weights + config["boundary_dist"] = self.boundary_dist + config["skip_background"] = self.skip_background + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert len(input_shape) == 2, 'CrossEntropy expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.skip_background = False if n_labels == 1 else self.skip_background + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, [0] * (1 + n_dims)) + + self.built = True + super(CrossEntropyLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] + pred = inputs[1] + if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps + gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + pred = K.clip(pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()) # to avoid log(0) + + # compare prediction/target, ce has the same shape has the input tensors + ce = -gt * tf.math.log(pred) + + # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) + if self.boundary_weights: + avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) + boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + if self.skip_background: + boundaries_channels = tf.unstack(boundaries, axis=-1) + boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundary_weights_tensor = 1 + self.boundary_weights * boundaries + ce *= boundary_weights_tensor + else: + boundary_weights_tensor = None + + # apply class weighting across labels. By the end of this, ce still has the same shape has the input tensors. + if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes, True) + else: + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + ce = tf.reduce_sum(ce * self.class_weights_tens, -1) + + # sum along label axis, and take the mean along spatial dimensions + ce = tf.math.reduce_mean(tf.math.reduce_sum(ce, axis=-1)) + + return ce + + def compute_output_shape(self, input_shape): + return [[]] + + +class MomentLoss(Layer): + """This layer computes a moment loss between two tensors. Specifically, it computes the distance between the centres + of gravity for all the channels of the two tensors, and then returns a value averaged across all channels. + These tensors are expected to have the same shape [batch, size_dim1, ..., size_dimN, n_channels]. + The first input tensor is the GT and the second is the prediction: moment_loss = MomentLoss()([gt, pred]) + + :param class_weights: (optional) if given, the loss is obtained by a weighted average of the Dice across labels. + Must be a sequence or 1d numpy array of length n_labels. Can also be -1, where the weights are dynamically set to + the inverse of the volume of each label in the ground truth. + :param enable_checks: (optional) whether to make sure that the 2 input tensors are probabilistic (i.e. the label + probabilities sum to 1 at each voxel location). Default is True. + """ + + def __init__(self, class_weights=None, enable_checks=False, **kwargs): + self.class_weights = class_weights + self.dynamic_weighting = False + self.class_weights_tens = None + self.enable_checks = enable_checks + self.spatial_axes = None + self.coordinates = None + super(MomentLoss, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["class_weights"] = self.class_weights + config["enable_checks"] = self.enable_checks + return config + + def build(self, input_shape): + + # get shape + assert len(input_shape) == 2, 'MomentLoss expects 2 inputs to compute the Dice loss.' + assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + inshape = input_shape[0][1:] + n_dims = len(inshape[:-1]) + n_labels = inshape[-1] + self.spatial_axes = list(range(1, n_dims + 1)) + + # build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan) + self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1) + self.coordinates = tf.cast(l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), 'float32') + + # build tensor with class weights + if self.class_weights is not None: + if self.class_weights == -1: + self.dynamic_weighting = True + else: + class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) + + self.built = True + super(MomentLoss, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # make sure tensors are probabilistic + gt = inputs[0] # (B, dim1, dim2, ..., dimN, nchan) + pred = inputs[1] + if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps + gt = gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + + # compute loss + gt_mean_coordinates = self._mean_coordinates(gt) # (B, ndim, nchan) + pred_mean_coordinates = self._mean_coordinates(pred) + loss = tf.math.sqrt(tf.reduce_sum(tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1)) # (B, nchan) + + # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). + if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) + if self.class_weights_tens is not None: + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + loss = tf.reduce_sum(loss * self.class_weights_tens, -1) + + return tf.math.reduce_mean(loss) + + def _mean_coordinates(self, tensor): + tensor = l2i_et.expand_dims(tensor, axis=-2) # (B, dim1, dim2, ..., dimN, 1, nchan) + numerator = tf.reduce_sum(tensor * self.coordinates, axis=self.spatial_axes) # (B, ndim, nchan) + denominator = tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + return numerator / denominator + + def compute_output_shape(self, input_shape): + return [[]] + + +class ResetValuesToZero(Layer): + """This layer enables to reset given values to 0 within the input tensors. + + :param values: list of values to be reset to 0. + + example: + input = tf.convert_to_tensor(np.array([[1, 0, 2, 2, 2, 2, 0], + [1, 3, 3, 3, 3, 3, 3], + [1, 0, 0, 0, 4, 4, 4]])) + values = [1, 3] + ResetValuesToZero(values)(input) + >> [[0, 0, 2, 2, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 4, 4]] + """ + + def __init__(self, values, **kwargs): + assert values is not None, 'please provide correct list of values, received None' + self.values = utils.reformat_to_list(values) + self.values_tens = None + self.n_values = len(values) + super(ResetValuesToZero, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["values"] = self.values + return config + + def build(self, input_shape): + self.values_tens = tf.convert_to_tensor(self.values) + self.built = True + super(ResetValuesToZero, self).build(input_shape) + + def call(self, inputs, **kwargs): + values = tf.cast(self.values_tens, dtype=inputs.dtype) + for i in range(self.n_values): + inputs = tf.where(tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs) + return inputs + + +class ConvertLabels(Layer): + """Convert all labels in a tensor by the corresponding given set of values. + labels_converted = ConvertLabels(source_values, dest_values)(labels). + labels must be an int32 tensor, and labels_converted will also be int32. + + :param source_values: list of all the possible values in labels. Must be a list or a 1D numpy array. + :param dest_values: list of all the target label values. Must be ordered the same as source values: + labels[labels == source_values[i]] = dest_values[i]. + If None (default), dest_values is equal to [0, ..., N-1], where N is the total number of values in source_values, + which enables to remap label maps to [0, ..., N-1]. + """ + + def __init__(self, source_values, dest_values=None, **kwargs): + self.source_values = source_values + self.dest_values = dest_values + self.lut = None + super(ConvertLabels, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["source_values"] = self.source_values + config["dest_values"] = self.dest_values + return config + + def build(self, input_shape): + self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32') + self.built = True + super(ConvertLabels, self).build(input_shape) + + def call(self, inputs, **kwargs): + return tf.gather(self.lut, tf.cast(inputs, dtype='int32')) + + +class PadAroundCentre(Layer): + """Pad the input tensor to the specified shape with the given value. + The input tensor is expected to have shape [batchsize, shape_dim1, ..., shape_dimn, channel]. + :param pad_margin: margin to use for padding. The tensor will be padded by the provided margin on each side. + Can either be a number (all axes padded with the same margin), or a list/numpy array of length n_dims. + example: if tensor is of shape [batch, x, y, z, n_channels] and margin=10, then the padded tensor will be of + shape [batch, x+2*10, y+2*10, z+2*10, n_channels]. + :param pad_shape: shape to pad the tensor to. Can either be a number (all axes padded to the same shape), or a + list/numpy array of length n_dims. + :param value: value to pad the tensors with. Default is 0. + """ + + def __init__(self, pad_margin=None, pad_shape=None, value=0, **kwargs): + self.pad_margin = pad_margin + self.pad_shape = pad_shape + self.value = value + self.pad_margin_tens = None + self.pad_shape_tens = None + self.n_dims = None + super(PadAroundCentre, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["pad_margin"] = self.pad_margin + config["pad_shape"] = self.pad_shape + config["value"] = self.value + return config + + def build(self, input_shape): + # input shape + self.n_dims = len(input_shape) - 2 + shape = list(input_shape) + shape[0] = 0 + shape[-1] = 0 + + if self.pad_margin is not None: + assert self.pad_shape is None, 'please do not provide a padding shape and margin at the same time.' + + # reformat padding margins + pad = np.transpose(np.array([[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] * 2)) + self.pad_margin_tens = tf.convert_to_tensor(pad, dtype='int32') + + elif self.pad_shape is not None: + assert self.pad_margin is None, 'please do not provide a padding shape and margin at the same time.' + + # pad shape + tensor_shape = tf.cast(tf.convert_to_tensor(shape), 'int32') + self.pad_shape_tens = np.array([0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0]) + self.pad_shape_tens = tf.convert_to_tensor(self.pad_shape_tens, dtype='int32') + self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens) + + # padding margin + min_margins = (self.pad_shape_tens - tensor_shape) / 2 + max_margins = self.pad_shape_tens - tensor_shape - min_margins + self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1) + + else: + raise Exception('please either provide a padding shape or a padding margin.') + + self.built = True + super(PadAroundCentre, self).build(input_shape) + + def call(self, inputs, **kwargs): + return tf.pad(inputs, self.pad_margin_tens, mode='CONSTANT', constant_values=self.value) + + +class MaskEdges(Layer): + """Reset the edges of a tensor to zero (i.e. with bands of zeros along the specified axes). + The width of the zero-band is randomly drawn from a uniform distribution, whose range is given in boundaries. + + :param axes: axes along which to reset edges to zero. Can be an int (single axis), or a sequence. + :param boundaries: numpy array of shape (len(axes), 4). Each row contains the two bounds of the uniform + distributions from which we draw the width of the zero-bands on each side. + Those bounds must be expressed in relative side (i.e. between 0 and 1). + :return: a tensor of the same shape as the input, with bands of zeros along the specified axes. + + example: + tensor=tf.constant([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]) # shape = [1,10,10,1] + axes=1 + boundaries = np.array([[0.2, 0.45, 0.85, 0.9]]) + + In this case, we reset the edges along the 2nd dimension (i.e. the 1st dimension after the batch dimension), + the 1st zero-band will expand from the 1st row to a number drawn from [0.2*tensor.shape[1], 0.45*tensor.shape[1]], + and the 2nd zero-band will expand from a row drawn from [0.85*tensor.shape[1], 0.9*tensor.shape[1]], to the end of + the tensor. A possible output could be: + array([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]) # shape = [1,10,10,1] + """ + + def __init__(self, axes, boundaries, prob_mask=1, **kwargs): + self.axes = utils.reformat_to_list(axes, dtype='int') + self.boundaries = utils.reformat_to_n_channels_array(boundaries, n_dims=4, n_channels=len(self.axes)) + self.prob_mask = prob_mask + self.inputshape = None + super(MaskEdges, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["axes"] = self.axes + config["boundaries"] = self.boundaries + config["prob_mask"] = self.prob_mask + return config + + def build(self, input_shape): + self.inputshape = input_shape + self.built = True + super(MaskEdges, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # build mask + mask = tf.ones_like(inputs) + for i, axis in enumerate(self.axes): + + # select restricting indices + axis_boundaries = self.boundaries[i, :] + idx1 = tf.math.round(tf.random.uniform([1], + minval=axis_boundaries[0] * self.inputshape[axis], + maxval=axis_boundaries[1] * self.inputshape[axis])) + idx2 = tf.math.round(tf.random.uniform([1], + minval=axis_boundaries[2] * self.inputshape[axis], + maxval=axis_boundaries[3] * self.inputshape[axis] - 1) - idx1) + idx3 = self.inputshape[axis] - idx1 - idx2 + split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype='int32') + + # update mask + split_list = tf.split(inputs, split_idx, axis=axis) + tmp_mask = tf.concat([tf.zeros_like(split_list[0]), + tf.ones_like(split_list[1]), + tf.zeros_like(split_list[2])], axis=axis) + mask = mask * tmp_mask + + # mask second_channel + tensor = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), + inputs * mask, + inputs) + + return [tensor, mask] + + def compute_output_shape(self, input_shape): + return [input_shape] * 2 + + +class ImageGradients(Layer): + + def __init__(self, gradient_type='sobel', return_magnitude=False, **kwargs): + + self.gradient_type = gradient_type + assert (self.gradient_type == 'sobel') | (self.gradient_type == '1-step_diff'), \ + 'gradient_type should be either sobel or 1-step_diff, had %s' % self.gradient_type + + # shape + self.n_dims = 0 + self.shape = None + self.n_channels = 0 + + # convolution params if sobel diff + self.stride = None + self.kernels = None + self.convnd = None + + self.return_magnitude = return_magnitude + + super(ImageGradients, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["gradient_type"] = self.gradient_type + config["return_magnitude"] = self.return_magnitude + return config + + def build(self, input_shape): + + # get shapes + self.n_dims = len(input_shape) - 2 + self.shape = input_shape[1:] + self.n_channels = input_shape[-1] + + # prepare kernel if sobel gradients + if self.gradient_type == 'sobel': + self.kernels = l2i_et.sobel_kernels(self.n_dims) + self.stride = [1] * (self.n_dims + 2) + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + else: + self.kernels = self.convnd = self.stride = None + + self.built = True + super(ImageGradients, self).build(input_shape) + + def call(self, inputs, **kwargs): + + image = inputs + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + gradients = list() + + # sobel method + if self.gradient_type == 'sobel': + # get sobel gradients in each direction + for n in range(self.n_dims): + gradient = image + # apply 1D kernel in each direction (sobel kernels are separable), instead of applying a nD kernel + for k in self.kernels[n]: + gradient = tf.concat([self.convnd(tf.expand_dims(gradient[..., n], -1), k, self.stride, 'SAME') + for n in range(self.n_channels)], -1) + gradients.append(gradient) + + # 1-step method, only supports 2 and 3D + else: + + # get 1-step diff + if self.n_dims == 2: + gradients.append(image[:, 1:, :, :] - image[:, :-1, :, :]) # dx + gradients.append(image[:, :, 1:, :] - image[:, :, :-1, :]) # dy + + elif self.n_dims == 3: + gradients.append(image[:, 1:, :, :, :] - image[:, :-1, :, :, :]) # dx + gradients.append(image[:, :, 1:, :, :] - image[:, :, :-1, :, :]) # dy + gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :]) # dz + + else: + raise Exception('ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD' % self.n_dims) + + # pad with zeros to return tensors of the same shape as input + for i in range(self.n_dims): + tmp_shape = list(self.shape) + tmp_shape[i] = 1 + zeros = tf.zeros(tf.concat([batchsize, tf.convert_to_tensor(tmp_shape, dtype='int32')], 0), image.dtype) + gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1) + + # compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis + if self.return_magnitude: + gradients = tf.sqrt(tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1)) + else: + gradients = tf.concat(gradients, axis=-1) + + return gradients + + def compute_output_shape(self, input_shape): + if not self.return_magnitude: + input_shape = list(input_shape) + input_shape[-1] = self.n_dims + return tuple(input_shape) + + +class RandomDilationErosion(Layer): + """ + GPU implementation of binary dilation or erosion. The operation can be chosen to be always a dilation, or always an + erosion, or randomly choosing between them for each element of the batch. + The chosen operation is applied to the input with a given probability. Moreover, it is also possible to randomise + the factor of the operation for each element of the mini-batch. + :param min_factor: minimum possible value for the dilation/erosion factor. Must be an integer. + :param max_factor: minimum possible value for the dilation/erosion factor. Must be an integer. + Set it to the same value as min_factor to always perform dilation/erosion with the same factor. + :param prob: probability with which to apply the selected operation to the input. + :param operation: which operation to apply. Can be 'dilation' or 'erosion' or 'random'. + :param return_mask: if operation is erosion and the input of this layer is a label map, we have the + choice to either return the eroded label map or the mask (return_mask=True) + """ + + def __init__(self, min_factor, max_factor, max_factor_dilate=None, prob=1, operation='random', return_mask=False, + **kwargs): + + self.min_factor = min_factor + self.max_factor = max_factor + self.max_factor_dilate = max_factor_dilate if max_factor_dilate is not None else self.max_factor + self.prob = prob + self.operation = operation + self.return_mask = return_mask + self.n_dims = None + self.inshape = None + self.n_channels = None + self.convnd = None + super(RandomDilationErosion, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["min_factor"] = self.min_factor + config["max_factor"] = self.max_factor + config["max_factor_dilate"] = self.max_factor_dilate + config["prob"] = self.prob + config["operation"] = self.operation + config["return_mask"] = self.return_mask + return config + + def build(self, input_shape): + + # input shape + self.inshape = input_shape + self.n_dims = len(self.inshape) - 2 + self.n_channels = self.inshape[-1] + + # prepare convolution + self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + + self.built = True + super(RandomDilationErosion, self).build(input_shape) + + def call(self, inputs, **kwargs): + + # sample probability of applying operation. If random negative is erosion and positive is dilation + batchsize = tf.split(tf.shape(inputs), [1, -1])[0] + shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype='int32')], axis=0) + if self.operation == 'dilation': + prob = tf.random.uniform(shape, 0, 1) + elif self.operation == 'erosion': + prob = tf.random.uniform(shape, -1, 0) + elif self.operation == 'random': + prob = tf.random.uniform(shape, -1, 1) + else: + raise ValueError("operation should either be 'dilation' 'erosion' or 'random', had %s" % self.operation) + + # build kernel + if self.min_factor == self.max_factor: + dist_threshold = self.min_factor * tf.ones(shape, dtype='int32') + else: + if (self.max_factor == self.max_factor_dilate) | (self.operation != 'random'): + dist_threshold = tf.random.uniform(shape, minval=self.min_factor, maxval=self.max_factor, dtype='int32') + else: + dist_threshold = tf.cast(tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), dtype='int32') + kernel = l2i_et.unit_kernel(dist_threshold, self.n_dims, max_dist_threshold=self.max_factor) + + # convolve input mask with kernel according to given probability + mask = tf.cast(tf.cast(inputs, dtype='bool'), dtype='float32') + mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32) + mask = tf.cast(mask, 'bool') + + if self.return_mask: + return mask + else: + return inputs * tf.cast(mask, dtype=inputs.dtype) + + def _sample_factor(self, inputs): + return tf.cast(K.switch(K.less(tf.squeeze(inputs[0]), 0), + tf.random.uniform((1,), self.min_factor, self.max_factor, dtype='int32'), + tf.random.uniform((1,), self.min_factor, self.max_factor_dilate, dtype='int32')), + dtype='float32') + + def _single_blur(self, inputs): + # dilate... + new_mask = K.switch(K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), + tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], + [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), + inputs[0]) + # ...or erode + new_mask = K.switch(K.less(tf.squeeze(inputs[2]), - (1 - self.prob + 0.001)), + 1 - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(1 - new_mask, 0), inputs[1], + [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), + new_mask) + return new_mask + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/nobrainer/ext/lab2im/utils.py b/nobrainer/ext/lab2im/utils.py new file mode 100644 index 00000000..1f3b8888 --- /dev/null +++ b/nobrainer/ext/lab2im/utils.py @@ -0,0 +1,1057 @@ +""" +This file contains all the utilities used in that project. They are classified in 5 categories: +1- loading/saving functions: + -load_volume + -save_volume + -get_volume_info + -get_list_labels + -load_array_if_path + -write_pickle + -read_pickle + -write_model_summary +2- reformatting functions + -reformat_to_list + -reformat_to_n_channels_array +3- path related functions + -list_images_in_folder + -list_files + -list_subfolders + -strip_extension + -strip_suffix + -mkdir + -mkcmd +4- shape-related functions + -get_dims + -get_resample_shape + -add_axis + -get_padding_margin +5- build affine matrices/tensors + -create_affine_transformation_matrix + -sample_affine_transform + -create_rotation_transform + -create_shearing_transform +6- miscellaneous + -infer + -LoopInfo + -get_mapping_lut + -build_training_generator + -find_closest_number_divisible_by_m + -build_binary_structure + -draw_value_from_distribution + -build_exp + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +import os +import glob +import math +import time +import pickle +import numpy as np +import nibabel as nib +import tensorflow as tf +import keras.layers as KL +import keras.backend as K +from datetime import timedelta +from scipy.ndimage.morphology import distance_transform_edt + + +# ---------------------------------------------- loading/saving functions ---------------------------------------------- + + +def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None): + """ + Load volume file. + :param path_volume: path of the volume to load. Can either be a nii, nii.gz, mgz, or npz format. + If npz format, 1) the variable name is assumed to be 'vol_data', + 2) the volume is associated with an identity affine matrix and blank header. + :param im_only: (optional) if False, the function also returns the affine matrix and header of the volume. + :param squeeze: (optional) whether to squeeze the volume when loading. + :param dtype: (optional) if not None, convert the loaded volume to this numpy dtype. + :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix. + The returned affine matrix is also given in this new space. Must be a numpy array of dimension 4x4. + :return: the volume, with corresponding affine matrix and header if im_only is False. + """ + assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume + + if path_volume.endswith(('.nii', '.nii.gz', '.mgz')): + x = nib.load(path_volume) + if squeeze: + volume = np.squeeze(x.get_fdata()) + else: + volume = x.get_fdata() + aff = x.affine + header = x.header + else: # npz + volume = np.load(path_volume)['vol_data'] + if squeeze: + volume = np.squeeze(volume) + aff = np.eye(4) + header = nib.Nifti1Header() + if dtype is not None: + if 'int' in dtype: + volume = np.round(volume) + volume = volume.astype(dtype=dtype) + + # align image to reference affine matrix + if aff_ref is not None: + from nobrainer.ext.lab2im import edit_volumes # the import is done here to avoid import loops + n_dims, _ = get_dims(list(volume.shape), max_channels=10) + volume, aff = edit_volumes.align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims) + + if im_only: + return volume + else: + return volume, aff, header + + +def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): + """ + Save a volume. + :param volume: volume to save + :param aff: affine matrix of the volume to save. If aff is None, the volume is saved with an identity affine matrix. + aff can also be set to 'FS', in which case the volume is saved with the affine matrix of FreeSurfer outputs. + :param header: header of the volume to save. If None, the volume is saved with a blank header. + :param path: path where to save the volume. + :param res: (optional) update the resolution in the header before saving the volume. + :param dtype: (optional) numpy dtype for the saved volume. + :param n_dims: (optional) number of dimensions, to avoid confusion in multi-channel case. Default is None, where + n_dims is automatically inferred. + """ + + mkdir(os.path.dirname(path)) + if '.npz' in path: + np.savez_compressed(path, vol_data=volume) + else: + if header is None: + header = nib.Nifti1Header() + if isinstance(aff, str): + if aff == 'FS': + aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) + elif aff is None: + aff = np.eye(4) + if dtype is not None: + if 'int' in dtype: + volume = np.round(volume) + volume = volume.astype(dtype=dtype) + nifty = nib.Nifti1Image(volume, aff, header) + nifty.set_data_dtype(dtype) + else: + nifty = nib.Nifti1Image(volume, aff, header) + if res is not None: + if n_dims is None: + n_dims, _ = get_dims(volume.shape) + res = reformat_to_list(res, length=n_dims, dtype=None) + nifty.header.set_zooms(res) + nib.save(nifty, path) + + +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): + """ + Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. + :param path_volume: path of the volume to get information form. + :param return_volume: (optional) whether to return the volume along with the information. + :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix. + All info relative to the volume is then given in this new space. Must be a numpy array of dimension 4x4. + :param max_channels: maximum possible number of channels for the input volume. + :return: volume (if return_volume is true), and corresponding info. If aff_ref is not None, the returned aff is + the original one, i.e. the affine of the image before being aligned to aff_ref. + """ + # read image + im, aff, header = load_volume(path_volume, im_only=False) + + # understand if image is multichannel + im_shape = list(im.shape) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) + im_shape = im_shape[:n_dims] + + # get labels res + if '.nii' in path_volume: + data_res = np.array(header['pixdim'][1:n_dims + 1]) + elif '.mgz' in path_volume: + data_res = np.array(header['delta']) # mgz image + else: + data_res = np.array([1.0] * n_dims) + + # align to given affine matrix + if aff_ref is not None: + from nobrainer.ext.lab2im import edit_volumes # the import is done here to avoid import loops + ras_axes = edit_volumes.get_ras_axes(aff, n_dims=n_dims) + ras_axes_ref = edit_volumes.get_ras_axes(aff_ref, n_dims=n_dims) + im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims) + im_shape = np.array(im_shape) + data_res = np.array(data_res) + im_shape[ras_axes_ref] = im_shape[ras_axes] + data_res[ras_axes_ref] = data_res[ras_axes] + im_shape = im_shape.tolist() + + # return info + if return_volume: + return im, im_shape, aff, n_dims, n_channels, header, data_res + else: + return im_shape, aff, n_dims, n_channels, header, data_res + + +def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_sort=False): + """This function reads or computes a list of all label values used in a set of label maps. + It can also sort all labels according to FreeSurfer lut. + :param label_list: (optional) already computed label_list. Can be a sequence, a 1d numpy array, or the path to + a numpy 1d array. + :param labels_dir: (optional) if path_label_list is None, the label list is computed by reading all the label maps + in the given folder. Can also be the path to a single label map. + :param save_label_list: (optional) path where to save the label list. + :param FS_sort: (optional) whether to sort label values according to the FreeSurfer classification. + If true, the label values will be ordered as follows: neutral labels first (i.e. non-sided), left-side labels, + and right-side labels. If FS_sort is True, this function also returns the number of neutral labels in label_list. + :return: the label list (numpy 1d array), and the number of neutral (i.e. non-sided) labels if FS_sort is True. + If one side of the brain is not represented at all in label_list, all labels are considered as neutral, and + n_neutral_labels = len(label_list). + """ + + # load label list if previously computed + if label_list is not None: + label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + + # compute label list from all label files + elif labels_dir is not None: + print('Compiling list of unique labels') + # go through all labels files and compute unique list of labels + labels_paths = list_images_in_folder(labels_dir) + label_list = np.empty(0) + loop_info = LoopInfo(len(labels_paths), 10, 'processing', print_time=True) + for lab_idx, path in enumerate(labels_paths): + loop_info.update(lab_idx) + y = load_volume(path, dtype='int32') + y_unique = np.unique(y) + label_list = np.unique(np.concatenate((label_list, y_unique))).astype('int') + + else: + raise Exception('either label_list, path_label_list or labels_dir should be provided') + + # sort labels in neutral/left/right according to FS labels + n_neutral_labels = 0 + if FS_sort: + neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108, + 109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, + 502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530, + 531, 532, 533, 534, 535, 536, 537] + neutral = list() + left = list() + right = list() + for la in label_list: + if la in neutral_FS_labels: + if la not in neutral: + neutral.append(la) + elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \ + (la == 865) | (20100 < la < 20110): + if la not in left: + left.append(la) + elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \ + (la == 866): + if la not in right: + right.append(la) + else: + raise Exception('label {} not in our current FS classification, ' + 'please update get_list_labels in utils.py'.format(la)) + label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)]) + if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)): + n_neutral_labels = len(neutral) + else: + n_neutral_labels = len(label_list) + + # save labels if specified + if save_label_list is not None: + np.save(save_label_list, np.int32(label_list)) + + if FS_sort: + return np.int32(label_list), n_neutral_labels + else: + return np.int32(label_list), None + + +def load_array_if_path(var, load_as_numpy=True): + """If var is a string and load_as_numpy is True, this function loads the array writen at the path indicated by var. + Otherwise it simply returns var as it is.""" + if (isinstance(var, str)) & load_as_numpy: + assert os.path.isfile(var), 'No such path: %s' % var + var = np.load(var) + return var + + +def write_pickle(filepath, obj): + """ write a python object with a pickle at a given path""" + with open(filepath, 'wb') as file: + pickler = pickle.Pickler(file) + pickler.dump(obj) + + +def read_pickle(filepath): + """ read a python object with a pickle""" + with open(filepath, 'rb') as file: + unpickler = pickle.Unpickler(file) + return unpickler.load() + + +def write_model_summary(model, filepath='./model_summary.txt', line_length=150): + """Write the summary of a keras model at a given path, with a given length for each line""" + with open(filepath, 'w') as fh: + model.summary(print_fn=lambda x: fh.write(x + '\n'), line_length=line_length) + + +# ----------------------------------------------- reformatting functions ----------------------------------------------- + + +def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): + """This function takes a variable and reformat it into a list of desired + length and type (int, float, bool, str). + If variable is a string, and load_as_numpy is True, it will be loaded as a numpy array. + If variable is None, this function returns None. + :param var: a str, int, float, list, tuple, or numpy array + :param length: (optional) if var is a single item, it will be replicated to a list of this length + :param load_as_numpy: (optional) whether var is the path to a numpy array + :param dtype: (optional) convert all item to this type. Can be 'int', 'float', 'bool', or 'str' + :return: reformatted list + """ + + # convert to list + if var is None: + return None + var = load_array_if_path(var, load_as_numpy=load_as_numpy) + if isinstance(var, (int, float, np.int32, np.int64, np.float32, np.float64)): + var = [var] + elif isinstance(var, tuple): + var = list(var) + elif isinstance(var, np.ndarray): + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() + elif isinstance(var, str): + var = [var] + elif isinstance(var, bool): + var = [var] + if isinstance(var, list): + if length is not None: + if len(var) == 1: + var = var * length + elif len(var) != length: + raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, ' + 'had {1}'.format(length, var)) + else: + raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array') + + # convert items type + if dtype is not None: + if dtype == 'int': + var = [int(v) for v in var] + elif dtype == 'float': + var = [float(v) for v in var] + elif dtype == 'bool': + var = [bool(v) for v in var] + elif dtype == 'str': + var = [str(v) for v in var] + else: + raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype)) + return var + + +def reformat_to_n_channels_array(var, n_dims=3, n_channels=1): + """This function takes an int, float, list or tuple and reformat it to an array of shape (n_channels, n_dims). + If resolution is a str, it will be assumed to be the path of a numpy array. + If resolution is a numpy array, it will be checked to have shape (n_channels, n_dims). + Finally if resolution is None, this function returns None as well.""" + if var is None: + return [None] * n_channels + if isinstance(var, str): + var = np.load(var) + # convert to numpy array + if isinstance(var, (int, float, list, tuple)): + var = reformat_to_list(var, n_dims) + var = np.tile(np.array(var), (n_channels, 1)) + # check shape if numpy array + elif isinstance(var, np.ndarray): + if n_channels == 1: + var = var.reshape((1, n_dims)) + else: + if np.squeeze(var).shape == (n_dims,): + var = np.tile(var.reshape((1, n_dims)), (n_channels, 1)) + elif var.shape != (n_channels, n_dims): + raise ValueError('if array, var should be {0} or {1}'.format((1, n_dims), (n_channels, n_dims))) + else: + raise TypeError('var should be int, float, list, tuple or ndarray') + return np.round(var, 3) + + +# ----------------------------------------------- path-related functions ----------------------------------------------- + + +def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True): + """List all files with extension nii, nii.gz, mgz, or npz within a folder.""" + basename = os.path.basename(path_dir) + if include_single_image & \ + (('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename)): + assert os.path.isfile(path_dir), 'file %s does not exist' % path_dir + list_images = [path_dir] + else: + if os.path.isdir(path_dir): + list_images = sorted(glob.glob(os.path.join(path_dir, '*nii.gz')) + + glob.glob(os.path.join(path_dir, '*nii')) + + glob.glob(os.path.join(path_dir, '*.mgz')) + + glob.glob(os.path.join(path_dir, '*.npz'))) + else: + raise Exception('Folder does not exist: %s' % path_dir) + if check_if_empty: + assert len(list_images) > 0, 'no .nii, .nii.gz, .mgz or .npz image could be found in %s' % path_dir + return list_images + + +def list_files(path_dir, whole_path=True, expr=None, cond_type='or'): + """This function returns a list of files contained in a folder, with possible regexp. + :param path_dir: path of a folder + :param whole_path: (optional) whether to return whole path or just the filenames. + :param expr: (optional) regexp for files to list. Can be a str or a list of str. + :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp. + Can be 'or', or 'and'. + :return: a list of files + """ + assert isinstance(whole_path, bool), "whole_path should be bool" + assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'" + if whole_path: + files_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir) + if os.path.isfile(os.path.join(path_dir, f))]) + else: + files_list = sorted([f for f in os.listdir(path_dir) if os.path.isfile(os.path.join(path_dir, f))]) + if expr is not None: # assumed to be either str or list of str + if isinstance(expr, str): + expr = [expr] + elif not isinstance(expr, (list, tuple)): + raise Exception("if specified, 'expr' should be a string or list of strings.") + matched_list_files = list() + for match in expr: + tmp_matched_files_list = sorted([f for f in files_list if match in os.path.basename(f)]) + if cond_type == 'or': + files_list = [f for f in files_list if f not in tmp_matched_files_list] + matched_list_files += tmp_matched_files_list + elif cond_type == 'and': + files_list = tmp_matched_files_list + matched_list_files = tmp_matched_files_list + files_list = sorted(matched_list_files) + return files_list + + +def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'): + """This function returns a list of subfolders contained in a folder, with possible regexp. + :param path_dir: path of a folder + :param whole_path: (optional) whether to return whole path or just the subfolder names. + :param expr: (optional) regexp for files to list. Can be a str or a list of str. + :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp. + Can be 'or', or 'and'. + :return: a list of subfolders + """ + assert isinstance(whole_path, bool), "whole_path should be bool" + assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'" + if whole_path: + subdirs_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir) + if os.path.isdir(os.path.join(path_dir, f))]) + else: + subdirs_list = sorted([f for f in os.listdir(path_dir) if os.path.isdir(os.path.join(path_dir, f))]) + if expr is not None: # assumed to be either str or list of str + if isinstance(expr, str): + expr = [expr] + elif not isinstance(expr, (list, tuple)): + raise Exception("if specified, 'expr' should be a string or list of strings.") + matched_list_subdirs = list() + for match in expr: + tmp_matched_list_subdirs = sorted([f for f in subdirs_list if match in os.path.basename(f)]) + if cond_type == 'or': + subdirs_list = [f for f in subdirs_list if f not in tmp_matched_list_subdirs] + matched_list_subdirs += tmp_matched_list_subdirs + elif cond_type == 'and': + subdirs_list = tmp_matched_list_subdirs + matched_list_subdirs = tmp_matched_list_subdirs + subdirs_list = sorted(matched_list_subdirs) + return subdirs_list + + +def get_image_extension(path): + name = os.path.basename(path) + if name[-7:] == '.nii.gz': + return 'nii.gz' + elif name[-4:] == '.mgz': + return 'mgz' + elif name[-4:] == '.nii': + return 'nii' + elif name[-4:] == '.npz': + return 'npz' + + +def strip_extension(path): + """Strip classical image extensions (.nii.gz, .nii, .mgz, .npz) from a filename.""" + return path.replace('.nii.gz', '').replace('.nii', '').replace('.mgz', '').replace('.npz', '') + + +def strip_suffix(path): + """Strip classical image suffix from a filename.""" + path = path.replace('_aseg', '') + path = path.replace('aseg', '') + path = path.replace('.aseg', '') + path = path.replace('_aseg_1', '') + path = path.replace('_aseg_2', '') + path = path.replace('aseg_1_', '') + path = path.replace('aseg_2_', '') + path = path.replace('_orig', '') + path = path.replace('orig', '') + path = path.replace('.orig', '') + path = path.replace('_norm', '') + path = path.replace('norm', '') + path = path.replace('.norm', '') + path = path.replace('_talairach', '') + path = path.replace('GSP_FS_4p5', 'GSP') + path = path.replace('.nii_crispSegmentation', '') + path = path.replace('_crispSegmentation', '') + path = path.replace('_seg', '') + path = path.replace('.seg', '') + path = path.replace('seg', '') + path = path.replace('_seg_1', '') + path = path.replace('_seg_2', '') + path = path.replace('seg_1_', '') + path = path.replace('seg_2_', '') + return path + + +def mkdir(path_dir): + """Recursively creates the current dir as well as its parent folders if they do not already exist.""" + if path_dir[-1] == '/': + path_dir = path_dir[:-1] + if not os.path.isdir(path_dir): + list_dir_to_create = [path_dir] + while not os.path.isdir(os.path.dirname(list_dir_to_create[-1])): + list_dir_to_create.append(os.path.dirname(list_dir_to_create[-1])) + for dir_to_create in reversed(list_dir_to_create): + os.mkdir(dir_to_create) + + +def mkcmd(*args): + """Creates terminal command with provided inputs. + Example: mkcmd('mv', 'source', 'dest') will give 'mv source dest'.""" + return ' '.join([str(arg) for arg in args]) + + +# ---------------------------------------------- shape-related functions ----------------------------------------------- + + +def get_dims(shape, max_channels=10): + """Get the number of dimensions and channels from the shape of an array. + The number of dimensions is assumed to be the length of the shape, as long as the shape of the last dimension is + inferior or equal to max_channels (default 3). + :param shape: shape of an array. Can be a sequence or a 1d numpy array. + :param max_channels: maximum possible number of channels. + :return: the number of dimensions and channels associated with the provided shape. + example 1: get_dims([150, 150, 150], max_channels=10) = (3, 1) + example 2: get_dims([150, 150, 150, 3], max_channels=10) = (3, 3) + example 3: get_dims([150, 150, 150, 15], max_channels=10) = (4, 1), because 5>3""" + if shape[-1] <= max_channels: + n_dims = len(shape) - 1 + n_channels = shape[-1] + else: + n_dims = len(shape) + n_channels = 1 + return n_dims, n_channels + + +def get_resample_shape(patch_shape, factor, n_channels=None): + """Compute the shape of a resampled array given a shape factor. + :param patch_shape: size of the initial array (without number of channels). + :param factor: resampling factor. Can be a number, sequence, or 1d numpy array. + :param n_channels: (optional) if not None, add a number of channel at the end of the computed shape. + :return: list containing the shape of the input array after being resampled by the given factor. + """ + factor = reformat_to_list(factor, length=len(patch_shape)) + shape = [math.ceil(patch_shape[i] * factor[i]) for i in range(len(patch_shape))] + if n_channels is not None: + shape += [n_channels] + return shape + + +def add_axis(x, axis=0): + """Add axis to a numpy array. + :param x: input array + :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" + axis = reformat_to_list(axis) + for ax in axis: + x = np.expand_dims(x, axis=ax) + return x + + +def get_padding_margin(cropping, loss_cropping): + """Compute padding margin""" + if (cropping is not None) & (loss_cropping is not None): + cropping = reformat_to_list(cropping) + loss_cropping = reformat_to_list(loss_cropping) + n_dims = max(len(cropping), len(loss_cropping)) + cropping = reformat_to_list(cropping, length=n_dims) + loss_cropping = reformat_to_list(loss_cropping, length=n_dims) + padding_margin = [int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims)] + if len(padding_margin) == 1: + padding_margin = padding_margin[0] + else: + padding_margin = None + return padding_margin + + +# -------------------------------------------- build affine matrices/tensors ------------------------------------------- + + +def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, shearing=None, translation=None): + """Create a 4x4 affine transformation matrix from specified values + :param n_dims: integer, can either be 2 or 3. + :param scaling: list of 3 scaling values + :param rotation: list of 3 angles (degrees) for rotations around 1st, 2nd, 3rd axis + :param shearing: list of 6 shearing values + :param translation: list of 3 values + :return: 4x4 numpy matrix + """ + + T_scaling = np.eye(n_dims + 1) + T_shearing = np.eye(n_dims + 1) + T_translation = np.eye(n_dims + 1) + + if scaling is not None: + T_scaling[np.arange(n_dims + 1), np.arange(n_dims + 1)] = np.append(scaling, 1) + + if shearing is not None: + shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype='bool') + shearing_index[np.eye(n_dims + 1, dtype='bool')] = False + shearing_index[-1, :] = np.zeros((n_dims + 1)) + shearing_index[:, -1] = np.zeros((n_dims + 1)) + T_shearing[shearing_index] = shearing + + if translation is not None: + T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype='int')] = translation + + if n_dims == 2: + if rotation is None: + rotation = np.zeros(1) + else: + rotation = np.asarray(rotation) * (math.pi / 180) + T_rot = np.eye(n_dims + 1) + T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[0]), np.sin(rotation[0]), + np.sin(rotation[0]) * -1, np.cos(rotation[0])] + return T_translation @ T_rot @ T_shearing @ T_scaling + + else: + + if rotation is None: + rotation = np.zeros(n_dims) + else: + rotation = np.asarray(rotation) * (math.pi / 180) + T_rot1 = np.eye(n_dims + 1) + T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [np.cos(rotation[0]), np.sin(rotation[0]), + np.sin(rotation[0]) * -1, np.cos(rotation[0])] + T_rot2 = np.eye(n_dims + 1) + T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [np.cos(rotation[1]), np.sin(rotation[1]) * -1, + np.sin(rotation[1]), np.cos(rotation[1])] + T_rot3 = np.eye(n_dims + 1) + T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[2]), np.sin(rotation[2]), + np.sin(rotation[2]) * -1, np.cos(rotation[2])] + return T_translation @ T_rot3 @ T_rot2 @ T_rot1 @ T_shearing @ T_scaling + + +def sample_affine_transform(batchsize, + n_dims, + rotation_bounds=False, + scaling_bounds=False, + shearing_bounds=False, + translation_bounds=False, + enable_90_rotations=False): + """build batchsize x 4 x 4 tensor representing an affine transformation in homogeneous coordinates. + If return_inv is True, also returns the inverse of the created affine matrix.""" + + if (rotation_bounds is not False) | (enable_90_rotations is not False): + if n_dims == 2: + if rotation_bounds is not False: + rotation = draw_value_from_distribution(rotation_bounds, + size=1, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize) + else: + rotation = tf.zeros(tf.concat([batchsize, tf.ones(1, dtype='int32')], axis=0)) + else: # n_dims = 3 + if rotation_bounds is not False: + rotation = draw_value_from_distribution(rotation_bounds, + size=n_dims, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize) + else: + rotation = tf.zeros(tf.concat([batchsize, 3 * tf.ones(1, dtype='int32')], axis=0)) + if enable_90_rotations: + rotation = tf.cast(tf.random.uniform(tf.shape(rotation), maxval=4, dtype='int32') * 90, 'float32') \ + + rotation + T_rot = create_rotation_transform(rotation, n_dims) + else: + T_rot = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + + if shearing_bounds is not False: + shearing = draw_value_from_distribution(shearing_bounds, + size=n_dims ** 2 - n_dims, + default_range=.01, + return_as_tensor=True, + batchsize=batchsize) + T_shearing = create_shearing_transform(shearing, n_dims) + else: + T_shearing = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + + if scaling_bounds is not False: + scaling = draw_value_from_distribution(scaling_bounds, + size=n_dims, + centre=1, + default_range=.15, + return_as_tensor=True, + batchsize=batchsize) + T_scaling = tf.linalg.diag(scaling) + else: + T_scaling = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + + T = tf.matmul(T_scaling, tf.matmul(T_shearing, T_rot)) + + if translation_bounds is not False: + translation = draw_value_from_distribution(translation_bounds, + size=n_dims, + default_range=5, + return_as_tensor=True, + batchsize=batchsize) + T = tf.concat([T, tf.expand_dims(translation, axis=-1)], axis=-1) + else: + T = tf.concat([T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype='int32')], 0))], axis=-1) + + # build rigid transform + T_last_row = tf.expand_dims(tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0) + T_last_row = tf.tile(T_last_row, tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + T = tf.concat([T, T_last_row], axis=1) + + return T + + +def create_rotation_transform(rotation, n_dims): + """build rotation transform from 3d or 2d rotation coefficients. Angles are given in degrees.""" + rotation = rotation * np.pi / 180 + if n_dims == 3: + shape = tf.shape(tf.expand_dims(rotation[..., 0], -1)) + + Rx_row0 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([1., 0., 0.]), 0), shape), axis=1) + Rx_row1 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(-tf.sin(rotation[..., 0]), -1)], axis=-1) + Rx_row2 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1) + Rx = tf.concat([Rx_row0, Rx_row1, Rx_row2], axis=1) + + Ry_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 1]), -1), tf.zeros(shape), + tf.expand_dims(tf.sin(rotation[..., 1]), -1)], axis=-1) + Ry_row1 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 1., 0.]), 0), shape), axis=1) + Ry_row2 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 1]), -1), tf.zeros(shape), + tf.expand_dims(tf.cos(rotation[..., 1]), -1)], axis=-1) + Ry = tf.concat([Ry_row0, Ry_row1, Ry_row2], axis=1) + + Rz_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 2]), -1), + tf.expand_dims(-tf.sin(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1) + Rz_row1 = tf.stack([tf.expand_dims(tf.sin(rotation[..., 2]), -1), + tf.expand_dims(tf.cos(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1) + Rz_row2 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 0., 1.]), 0), shape), axis=1) + Rz = tf.concat([Rz_row0, Rz_row1, Rz_row2], axis=1) + + T_rot = tf.matmul(tf.matmul(Rx, Ry), Rz) + + elif n_dims == 2: + R_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(tf.sin(rotation[..., 0]), -1)], axis=-1) + R_row1 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1) + T_rot = tf.concat([R_row0, R_row1], axis=1) + + else: + raise Exception('only supports 2 or 3D.') + + return T_rot + + +def create_shearing_transform(shearing, n_dims): + """build shearing transform from 2d/3d shearing coefficients""" + shape = tf.shape(tf.expand_dims(shearing[..., 0], -1)) + if n_dims == 3: + shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1), + tf.expand_dims(shearing[..., 1], -1)], axis=-1) + shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 2], -1), tf.ones(shape), + tf.expand_dims(shearing[..., 3], -1)], axis=-1) + shearing_row2 = tf.stack([tf.expand_dims(shearing[..., 4], -1), tf.expand_dims(shearing[..., 5], -1), + tf.ones(shape)], axis=-1) + T_shearing = tf.concat([shearing_row0, shearing_row1, shearing_row2], axis=1) + + elif n_dims == 2: + shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1) + shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1) + T_shearing = tf.concat([shearing_row0, shearing_row1], axis=1) + else: + raise Exception('only supports 2 or 3D.') + return T_shearing + + +# --------------------------------------------------- miscellaneous ---------------------------------------------------- + + +def infer(x): + """ Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string """ + try: + x = float(x) + except ValueError: + if x == 'False': + x = False + elif x == 'True': + x = True + elif not isinstance(x, str): + raise TypeError('input should be an int/float/boolean/str, had {}'.format(type(x))) + return x + + +class LoopInfo: + """ + Class to print the current iteration in a for loop, and optionally the estimated remaining time. + Instantiate just before the loop, and call the update method at the start of the loop. + The printed text has the following format: + processing i/total remaining time: hh:mm:ss + """ + + def __init__(self, n_iterations, spacing=10, text='processing', print_time=False): + """ + :param n_iterations: total number of iterations of the for loop. + :param spacing: frequency at which the update info will be printed on screen. + :param text: text to print. Default is processing. + :param print_time: whether to print the estimated remaining time. Default is False. + """ + + # loop parameters + self.n_iterations = n_iterations + self.spacing = spacing + + # text parameters + self.text = text + self.print_time = print_time + self.print_previous_time = False + self.align = len(str(self.n_iterations)) * 2 + 1 + 3 + + # timing parameters + self.iteration_durations = np.zeros((n_iterations,)) + self.start = time.time() + self.previous = time.time() + + def update(self, idx): + + # time iteration + now = time.time() + self.iteration_durations[idx] = now - self.previous + self.previous = now + + # print text + if idx == 0: + print(self.text + ' 1/{}'.format(self.n_iterations)) + elif idx % self.spacing == self.spacing - 1: + iteration = str(idx + 1) + '/' + str(self.n_iterations) + if self.print_time: + # estimate remaining time + max_duration = np.max(self.iteration_durations) + average_duration = np.mean(self.iteration_durations[self.iteration_durations > .01 * max_duration]) + remaining_time = int(average_duration * (self.n_iterations - idx)) + # print total remaining time only if it is greater than 1s or if it was previously printed + if (remaining_time > 1) | self.print_previous_time: + eta = str(timedelta(seconds=remaining_time)) + print(self.text + ' {:<{x}} remaining time: {}'.format(iteration, eta, x=self.align)) + self.print_previous_time = True + else: + print(self.text + ' {}'.format(iteration)) + else: + print(self.text + ' {}'.format(iteration)) + + +def get_mapping_lut(source, dest=None): + """This functions returns the look-up table to map a list of N values (source) to another list (dest). + If the second list is not given, we assume it is equal to [0, ..., N-1].""" + + # initialise + source = np.array(reformat_to_list(source), dtype='int32') + n_labels = source.shape[0] + + # build new label list if necessary + if dest is None: + dest = np.arange(n_labels, dtype='int32') + else: + assert len(source) == len(dest), 'label_list and new_label_list should have the same length' + dest = np.array(reformat_to_list(dest, dtype='int')) + + # build look-up table + lut = np.zeros(np.max(source) + 1, dtype='int32') + for source, dest in zip(source, dest): + lut[source] = dest + + return lut + + +def build_training_generator(gen, batchsize): + """Build generator for training a network.""" + while True: + inputs = next(gen) + if batchsize > 1: + target = np.concatenate([np.zeros((1, 1))] * batchsize, 0) + else: + target = np.zeros((1, 1)) + yield inputs, target + + +def find_closest_number_divisible_by_m(n, m, answer_type='lower'): + """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns + values lower than n), or 'higher' (only returns values higher than m).""" + if n % m == 0: + return n + else: + q = int(n / m) + lower = q * m + higher = (q + 1) * m + if answer_type == 'lower': + return lower + elif answer_type == 'higher': + return higher + elif answer_type == 'closer': + return lower if (n - lower) < (higher - n) else higher + else: + raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type) + + +def build_binary_structure(connectivity, n_dims, shape=None): + """Return a dilation/erosion element with provided connectivity""" + if shape is None: + shape = [connectivity * 2 + 1] * n_dims + else: + shape = reformat_to_list(shape, length=n_dims) + dist = np.ones(shape) + center = tuple([tuple([int(s / 2)]) for s in shape]) + dist[center] = 0 + dist = distance_transform_edt(dist) + struct = (dist <= connectivity) * 1 + return struct + + +def draw_value_from_distribution(hyperparameter, + size=1, + distribution='uniform', + centre=0., + default_range=10.0, + positive_only=False, + return_as_tensor=False, + batchsize=None): + """Sample values from a uniform, or normal distribution of given hyperparameters. + These hyperparameters are to the number of 2 in both uniform and normal cases. + :param hyperparameter: values of the hyperparameters. Can either be: + 1) None, in each case the two hyperparameters are given by [center-default_range, center+default_range], + 2) a number, where the two hyperparameters are given by [centre-hyperparameter, centre+hyperparameter], + 3) a sequence of length 2, directly defining the two hyperparameters: [min, max] if the distribution is uniform, + [mean, std] if the distribution is normal. + 4) a numpy array, with size (2, m). In this case, the function returns a 1d array of size m, where each value has + been sampled independently with the specified hyperparameters. If the distribution is uniform, rows correspond to + its lower and upper bounds, and if the distribution is normal, rows correspond to its mean and std deviation. + 5) a numpy array of size (2*n, m). Same as 4) but we first randomly select a block of two rows among the + n possibilities. + 6) the path to a numpy array corresponding to case 4 or 5. + 7) False, in which case this function returns None. + :param size: (optional) number of values to sample. All values are sampled independently. + Used only if hyperparameter is not a numpy array. + :param distribution: (optional) the distribution type. Can be 'uniform' or 'normal'. Default is 'uniform'. + :param centre: (optional) default centre to use if hyperparameter is None or a number. + :param default_range: (optional) default range to use if hyperparameter is None. + :param positive_only: (optional) whether to reset all negative values to zero. + :param return_as_tensor: (optional) whether to return the result as a tensorflow tensor + :param batchsize: (optional) if return_as_tensor is true, then you can sample a tensor of a given batchsize. Give + this batchsize as a tensorflow tensor here. + :return: a float, or a numpy 1d array if size > 1, or hyperparameter is itself a numpy array. + Returns None if hyperparameter is False. + """ + + # return False is hyperparameter is False + if hyperparameter is False: + return None + + # reformat parameter_range + hyperparameter = load_array_if_path(hyperparameter, load_as_numpy=True) + if not isinstance(hyperparameter, np.ndarray): + if hyperparameter is None: + hyperparameter = np.array([[centre - default_range] * size, [centre + default_range] * size]) + elif isinstance(hyperparameter, (int, float)): + hyperparameter = np.array([[centre - hyperparameter] * size, [centre + hyperparameter] * size]) + elif isinstance(hyperparameter, (list, tuple)): + assert len(hyperparameter) == 2, 'if list, parameter_range should be of length 2.' + hyperparameter = np.transpose(np.tile(np.array(hyperparameter), (size, 1))) + else: + raise ValueError('parameter_range should either be None, a number, a sequence, or a numpy array.') + elif isinstance(hyperparameter, np.ndarray): + assert hyperparameter.shape[0] % 2 == 0, 'number of rows of parameter_range should be divisible by 2' + n_modalities = int(hyperparameter.shape[0] / 2) + modality_idx = 2 * np.random.randint(n_modalities) + hyperparameter = hyperparameter[modality_idx: modality_idx + 2, :] + + # draw values as tensor + if return_as_tensor: + shape = KL.Lambda(lambda x: tf.convert_to_tensor(hyperparameter.shape[1], 'int32'))([]) + if batchsize is not None: + shape = KL.Lambda(lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0))([batchsize, shape]) + if distribution == 'uniform': + parameter_value = KL.Lambda(lambda x: tf.random.uniform(shape=x, + minval=hyperparameter[0, :], + maxval=hyperparameter[1, :]))(shape) + elif distribution == 'normal': + parameter_value = KL.Lambda(lambda x: tf.random.normal(shape=x, + mean=hyperparameter[0, :], + stddev=hyperparameter[1, :]))(shape) + else: + raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.") + + if positive_only: + parameter_value = KL.Lambda(lambda x: K.clip(x, 0, None))(parameter_value) + + # draw values as numpy array + else: + if distribution == 'uniform': + parameter_value = np.random.uniform(low=hyperparameter[0, :], high=hyperparameter[1, :]) + elif distribution == 'normal': + parameter_value = np.random.normal(loc=hyperparameter[0, :], scale=hyperparameter[1, :]) + else: + raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.") + + if positive_only: + parameter_value[parameter_value < 0] = 0 + + return parameter_value + + +def build_exp(x, first, last, fix_point): + # first = f(0), last = f(+inf), fix_point = [x0, f(x0))] + a = last + b = first - last + c = - (1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last)) + return a + b * np.exp(-c * x) diff --git a/nobrainer/ext/neuron/__init__.py b/nobrainer/ext/neuron/__init__.py new file mode 100644 index 00000000..2f28f4d4 --- /dev/null +++ b/nobrainer/ext/neuron/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import models +from . import utils diff --git a/nobrainer/ext/neuron/layers.py b/nobrainer/ext/neuron/layers.py new file mode 100644 index 00000000..61b46a78 --- /dev/null +++ b/nobrainer/ext/neuron/layers.py @@ -0,0 +1,435 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +or for the transformation/integration functions: + +Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration +Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu +MICCAI 2018. + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +# third party +import tensorflow as tf +from keras import backend as K +from keras.layers import Layer +from copy import deepcopy + +# local +from nobrainer.ext.neuron.utils import transform, resize, integrate_vec, affine_to_shift, combine_non_linear_and_aff_to_shift + + +class SpatialTransformer(Layer): + """ + N-D Spatial Transformer Tensorflow / Keras Layer + + The Layer can handle both affine and dense transforms. + Both transforms are meant to give a 'shift' from the current position. + Therefore, a dense transform gives displacements (not absolute locations) at each voxel, + and an affine transform gives the *difference* of the affine matrix from + the identity matrix. + + If you find this function useful, please cite: + Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration + Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu + MICCAI 2018. + + Originally, this code was based on voxelmorph code, which + was in turn transformed to be dense with the help of (affine) STN code + via https://github.com/kevinzakka/spatial-transformer-network + + Since then, we've re-written the code to be generalized to any + dimensions, and along the way wrote grid and interpolation functions + """ + + def __init__(self, + interp_method='linear', + indexing='ij', + single_transform=False, + **kwargs): + """ + Parameters: + interp_method: 'linear' or 'nearest' + single_transform: whether a single transform supplied for the whole batch + indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian) + 'xy' indexing will have the first two entries of the flow + (along last axis) flipped compared to 'ij' indexing + """ + self.interp_method = interp_method + self.ndims = None + self.inshape = None + self.single_transform = single_transform + self.is_affine = list() + + assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + self.indexing = indexing + + super(self.__class__, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["interp_method"] = self.interp_method + config["indexing"] = self.indexing + config["single_transform"] = self.single_transform + return config + + def build(self, input_shape): + """ + input_shape should be a list for two inputs: + input1: image. + input2: list of transform Tensors + if affine: + should be an N+1 x N+1 matrix + *or* a N+1*N+1 tensor (which will be reshaped to N x (N+1) and an identity row added) + if not affine: + should be a *vol_shape x N + """ + + if len(input_shape) > 3: + raise Exception('Spatial Transformer must be called on a list of min length 2 and max length 3.' + 'First argument is the image followed by the affine and non linear transforms.') + + # set up number of dimensions + self.ndims = len(input_shape[0]) - 2 + self.inshape = input_shape + trf_shape = [trans_shape[1:] for trans_shape in input_shape[1:]] + + for (i, shape) in enumerate(trf_shape): + + # the transform is an affine iff: + # it's a 1D Tensor [dense transforms need to be at least ndims + 1] + # it's a 2D Tensor and shape == [N+1, N+1]. + self.is_affine.append(len(shape) == 1 or + (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape]))) + + # check sizes + if self.is_affine[i] and len(shape) == 1: + ex = self.ndims * (self.ndims + 1) + if shape[0] != ex: + raise Exception('Expected flattened affine of len %d but got %d' % (ex, shape[0])) + + if not self.is_affine[i]: + if shape[-1] != self.ndims: + raise Exception('Offset flow field size expected: %d, found: %d' % (self.ndims, shape[-1])) + + # confirm built + self.built = True + + def call(self, inputs, **kwargs): + """ + Parameters + inputs: list with several entries: the volume followed by the transforms + """ + + # check shapes + assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len(inputs) + vol = inputs[0] + trf = inputs[1:] + + # necessary for multi_gpu models... + vol = K.reshape(vol, [-1, *self.inshape[0][1:]]) + for i in range(len(trf)): + trf[i] = K.reshape(trf[i], [-1, *self.inshape[i+1][1:]]) + + # reorder transforms, non-linear first and affine second + ind_nonlinear_linear = [i[0] for i in sorted(enumerate(self.is_affine), key=lambda x:x[1])] + self.is_affine = [self.is_affine[i] for i in ind_nonlinear_linear] + self.inshape = [self.inshape[i] for i in ind_nonlinear_linear] + trf = [trf[i] for i in ind_nonlinear_linear] + + # go from affine to deformation field + if len(trf) == 1: + trf = trf[0] + if self.is_affine[0]: + trf = tf.map_fn(lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32) + # combine non-linear and affine to obtain a single deformation field + elif len(trf) == 2: + trf = tf.map_fn(lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32) + + # prepare location shift + if self.indexing == 'xy': # shift the first two dimensions + trf_split = tf.split(trf, trf.shape[-1], axis=-1) + trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]] + trf = tf.concat(trf_lst, -1) + + # map transform across batch + if self.single_transform: + return tf.map_fn(self._single_transform, [vol, trf[0, :]], dtype=tf.float32) + else: + return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32) + + def _single_aff_to_shift(self, trf, volshape): + if len(trf.shape) == 1: # go from vector to matrix + trf = tf.reshape(trf, [self.ndims, self.ndims + 1]) + return affine_to_shift(trf, volshape, shift_center=True) + + def _non_linear_and_aff_to_shift(self, trf, volshape): + if len(trf[1].shape) == 1: # go from vector to matrix + trf[1] = tf.reshape(trf[1], [self.ndims, self.ndims + 1]) + return combine_non_linear_and_aff_to_shift(trf, volshape, shift_center=True) + + def _single_transform(self, inputs): + return transform(inputs[0], inputs[1], interp_method=self.interp_method) + + +class VecInt(Layer): + """ + Vector Integration Layer + + Enables vector integration via several methods + (ode or quadrature for time-dependent vector fields, + scaling and squaring for stationary fields) + + If you find this function useful, please cite: + Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration + Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu + MICCAI 2018. + """ + + def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1, + ode_args=None, + odeint_fn=None, **kwargs): + """ + Parameters: + method can be any of the methods in neuron.utils.integrate_vec + indexing can be 'xy' (switches first two dimensions) or 'ij' + int_steps is the number of integration steps + out_time_pt is time point at which to output if using odeint integration + """ + + assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + self.indexing = indexing + self.method = method + self.int_steps = int_steps + self.inshape = None + self.out_time_pt = out_time_pt + self.odeint_fn = odeint_fn # if none then will use a tensorflow function + self.ode_args = ode_args + if ode_args is None: + self.ode_args = {'rtol': 1e-6, 'atol': 1e-12} + super(self.__class__, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["indexing"] = self.indexing + config["method"] = self.method + config["int_steps"] = self.int_steps + config["out_time_pt"] = self.out_time_pt + config["ode_args"] = self.ode_args + config["odeint_fn"] = self.odeint_fn + return config + + def build(self, input_shape): + # confirm built + self.built = True + + trf_shape = input_shape + if isinstance(input_shape[0], (list, tuple)): + trf_shape = input_shape[0] + self.inshape = trf_shape + + if trf_shape[-1] != len(trf_shape) - 2: + raise Exception('transform ndims %d does not match expected ndims %d' % (trf_shape[-1], len(trf_shape) - 2)) + + def call(self, inputs, **kwargs): + if not isinstance(inputs, (list, tuple)): + inputs = [inputs] + loc_shift = inputs[0] + + # necessary for multi_gpu models... + loc_shift = K.reshape(loc_shift, [-1, *self.inshape[1:]]) + + # prepare location shift + if self.indexing == 'xy': # shift the first two dimensions + loc_shift_split = tf.split(loc_shift, loc_shift.shape[-1], axis=-1) + loc_shift_lst = [loc_shift_split[1], loc_shift_split[0], *loc_shift_split[2:]] + loc_shift = tf.concat(loc_shift_lst, -1) + + if len(inputs) > 1: + assert self.out_time_pt is None, 'out_time_pt should be None if providing batch_based out_time_pt' + + # map transform across batch + out = tf.map_fn(self._single_int, [loc_shift] + inputs[1:], dtype=tf.float32) + return out + + def _single_int(self, inputs): + + vel = inputs[0] + out_time_pt = self.out_time_pt + if len(inputs) == 2: + out_time_pt = inputs[1] + return integrate_vec(vel, method=self.method, + nb_steps=self.int_steps, + ode_args=self.ode_args, + out_time_pt=out_time_pt, + odeint_fn=self.odeint_fn) + + +class Resize(Layer): + """ + N-D Resize Tensorflow / Keras Layer + Note: this is not re-shaping an existing volume, but resizing, like scipy's "Zoom" + + If you find this function useful, please cite: + Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,Dalca AV, Guttag J, Sabuncu MR + CVPR 2018 + + Since then, we've re-written the code to be generalized to any + dimensions, and along the way wrote grid and interpolation functions + """ + + def __init__(self, + zoom_factor=None, + size=None, + interp_method='linear', + **kwargs): + """ + Parameters: + interp_method: 'linear' or 'nearest' + 'xy' indexing will have the first two entries of the flow + (along last axis) flipped compared to 'ij' indexing + """ + self.zoom_factor = zoom_factor + self.size = list(size) + self.zoom_factor0 = None + self.size0 = None + self.interp_method = interp_method + self.ndims = None + self.inshape = None + super(Resize, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["zoom_factor"] = self.zoom_factor + config["size"] = self.size + config["interp_method"] = self.interp_method + return config + + def build(self, input_shape): + """ + input_shape should be an element of list of one inputs: + input1: volume + should be a *vol_shape x N + """ + + if isinstance(input_shape[0], (list, tuple)) and len(input_shape) > 1: + raise Exception('Resize must be called on a list of length 1.') + + if isinstance(input_shape[0], (list, tuple)): + input_shape = input_shape[0] + + # set up number of dimensions + self.ndims = len(input_shape) - 2 + self.inshape = input_shape + + # check zoom_factor + if isinstance(self.zoom_factor, float): + self.zoom_factor0 = [self.zoom_factor] * self.ndims + elif self.zoom_factor is None: + self.zoom_factor0 = [0] * self.ndims + elif isinstance(self.zoom_factor, (list, tuple)): + self.zoom_factor0 = deepcopy(self.zoom_factor) + assert len(self.zoom_factor0) == self.ndims, \ + 'zoom factor length {} does not match number of dimensions {}'.format(len(self.zoom_factor), self.ndims) + else: + raise Exception('zoom_factor should be an int or a list/tuple of int (or None if size is not set to None)') + + # check size + if isinstance(self.size, int): + self.size0 = [self.size] * self.ndims + elif self.size is None: + self.size0 = [0] * self.ndims + elif isinstance(self.size, (list, tuple)): + self.size0 = deepcopy(self.size) + assert len(self.size0) == self.ndims, \ + 'size length {} does not match number of dimensions {}'.format(len(self.size0), self.ndims) + else: + raise Exception('size should be an int or a list/tuple of int (or None if zoom_factor is not set to None)') + + # confirm built + self.built = True + + super(Resize, self).build(input_shape) # Be sure to call this somewhere! + + def call(self, inputs, **kwargs): + """ + Parameters + inputs: volume or list of one volume + """ + + # check shapes + if isinstance(inputs, (list, tuple)): + assert len(inputs) == 1, "inputs has to be len 1. found: %d" % len(inputs) + vol = inputs[0] + else: + vol = inputs + + # necessary for multi_gpu models... + vol = K.reshape(vol, [-1, *self.inshape[1:]]) + + # set value of missing size or zoom_factor + if not any(self.zoom_factor0): + self.zoom_factor0 = [self.size0[i] / self.inshape[i+1] for i in range(self.ndims)] + else: + self.size0 = [int(self.inshape[f+1] * self.zoom_factor0[f]) for f in range(self.ndims)] + + # map transform across batch + return tf.map_fn(self._single_resize, vol, dtype=vol.dtype) + + def compute_output_shape(self, input_shape): + + output_shape = [input_shape[0]] + output_shape += [int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims)] + output_shape += [input_shape[-1]] + return tuple(output_shape) + + def _single_resize(self, inputs): + return resize(inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method) + + +# Zoom naming of resize, to match scipy's naming +Zoom = Resize + + +######################################################### +# "Local" layers -- layers with parameters at each voxel +######################################################### + +class LocalBias(Layer): + """ + Local bias layer: each pixel/voxel has its own bias operation (one parameter) + out[v] = in[v] + b + """ + + def __init__(self, my_initializer='RandomNormal', biasmult=1.0, **kwargs): + self.initializer = my_initializer + self.biasmult = biasmult + self.kernel = None + super(LocalBias, self).__init__(**kwargs) + + def get_config(self): + config = super().get_config() + config["my_initializer"] = self.initializer + config["biasmult"] = self.biasmult + return config + + def build(self, input_shape): + # Create a trainable weight variable for this layer. + self.kernel = self.add_weight(name='kernel', + shape=input_shape[1:], + initializer=self.initializer, + trainable=True) + super(LocalBias, self).build(input_shape) # Be sure to call this somewhere! + + def call(self, x, **kwargs): + return x + self.kernel * self.biasmult # weights are difference from input + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/nobrainer/ext/neuron/models.py b/nobrainer/ext/neuron/models.py new file mode 100644 index 00000000..9b5c87ed --- /dev/null +++ b/nobrainer/ext/neuron/models.py @@ -0,0 +1,768 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import sys + +from nobrainer.ext.neuron import layers + +# third party +import numpy as np +import tensorflow as tf +import keras +import keras.layers as KL +from keras.models import Model +import keras.backend as K + + +def unet(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name='unet', + prefix=None, + feat_mult=1, + pool_size=2, + use_logp=True, + padding='same', + dilation_rate_mult=1, + activation='elu', + skip_n_concatenations=0, + use_residuals=False, + final_pred_activation='softmax', + nb_conv_per_level=1, + add_prior_layer=False, + layer_nb_feats=None, + conv_dropout=0, + batch_norm=None, + input_model=None): + """ + unet-style keras model with an overdose of parametrization. + + Parameters: + nb_features: the number of features at each convolutional level + see below for `feat_mult` and `layer_nb_feats` for modifiers to this number + input_shape: input layer shape, vector of size ndims + 1 (nb_channels) + conv_size: the convolution kernel size + nb_levels: the number of Unet levels (number of downsamples) in the "encoder" + (e.g. 4 would give you 4 levels in encoder, 4 in decoder) + nb_labels: number of output channels + name (default: 'unet'): the name of the network + prefix (default: `name` value): prefix to be added to layer names + feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels. + e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the + second layer, 64 features in the third layer, etc. + pool_size (default: 2): max pooling size (integer or list if specifying per dimension) + skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n + top levels. + use_logp: + padding: + dilation_rate_mult: + activation: + use_residuals: + final_pred_activation: + nb_conv_per_level: + add_prior_layer: + skip_n_concatenations: + layer_nb_feats: list of the number of features for each layer. Automatically used if specified + conv_dropout: dropout probability + batch_norm: + input_model: concatenate the provided input_model to this current model. + Only the first output of input_model is used. + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + layer_nb_feats=layer_nb_feats, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model) + + # get decoder + # use_skip_connections=True makes it a u-net + lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None + dec_model = conv_dec(nb_features, + [], + nb_levels, + conv_size, + nb_labels, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=True, + skip_n_concatenations=skip_n_concatenations, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + final_pred_activation='linear' if add_prior_layer else final_pred_activation, + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + layer_nb_feats=lnf, + conv_dropout=conv_dropout, + input_model=enc_model) + final_model = dec_model + + if add_prior_layer: + final_model = add_prior(dec_model, + [*input_shape[:-1], nb_labels], + name=model_name + '_prior', + use_logp=use_logp, + final_pred_activation=final_pred_activation) + + return final_model + + +def ae(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + enc_size, + name='ae', + feat_mult=1, + pool_size=2, + padding='same', + activation='elu', + use_residuals=False, + nb_conv_per_level=1, + batch_norm=None, + enc_batch_norm=None, + ae_type='conv', # 'dense', or 'conv' + enc_lambda_layers=None, + add_prior_layer=False, + use_logp=True, + conv_dropout=0, + include_mu_shift_layer=False, + single_model=False, # whether to return a single model, or a tuple of models that can be stacked. + final_pred_activation='softmax', + do_vae=False, + input_model=None): + """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).""" + + # naming + model_name = name + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # get encoding model + enc_model = conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model) + + # middle AE structure + if single_model: + in_input_shape = None + in_model = enc_model + else: + in_input_shape = enc_model.output.shape.as_list()[1:] + in_model = None + mid_ae_model = single_ae(enc_size, + in_input_shape, + conv_size=conv_size, + name=model_name, + ae_type=ae_type, + input_model=in_model, + batch_norm=enc_batch_norm, + enc_lambda_layers=enc_lambda_layers, + include_mu_shift_layer=include_mu_shift_layer, + do_vae=do_vae) + + # decoder + if single_model: + in_input_shape = None + in_model = mid_ae_model + else: + in_input_shape = mid_ae_model.output.shape.as_list()[1:] + in_model = None + dec_model = conv_dec(nb_features, + in_input_shape, + nb_levels, + conv_size, + nb_labels, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=False, + padding=padding, + activation=activation, + use_residuals=use_residuals, + final_pred_activation='linear', + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + conv_dropout=conv_dropout, + input_model=in_model) + + if add_prior_layer: + dec_model = add_prior(dec_model, + [*input_shape[:-1], nb_labels], + name=model_name, + prefix=model_name + '_prior', + use_logp=use_logp, + final_pred_activation=final_pred_activation) + + if single_model: + return dec_model + else: + return dec_model, mid_ae_model, enc_model + + +def conv_enc(nb_features, + input_shape, + nb_levels, + conv_size, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + dilation_rate_mult=1, + padding='same', + activation='elu', + layer_nb_feats=None, + use_residuals=False, + nb_conv_per_level=2, + conv_dropout=0, + batch_norm=None, + input_model=None): + """Fully Convolutional Encoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # first layer: input + name = '%s_input' % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=name) + last_tensor = input_tensor + else: + input_tensor = input_model.inputs + last_tensor = input_model.outputs + if isinstance(last_tensor, list): + last_tensor = last_tensor[0] + + # volume size data + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, 'Conv%dD' % ndims) + conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'} + maxpool = getattr(KL, 'MaxPooling%dD' % ndims) + + # down arm: + # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers + lfidx = 0 # level feature index + for level in range(nb_levels): + lvl_first_tensor = last_tensor + nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int) + conv_kwargs['dilation_rate'] = dilation_rate_mult ** level + + for conv in range(nb_conv_per_level): # does several conv per level, max pooling applied at the end + if layer_nb_feats is not None: # None or List of all the feature numbers + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = '%s_conv_downarm_%d_%d' % (prefix, level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + else: # no activation + last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + + if conv_dropout > 0: + # conv dropout along feature space only + name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + + if use_residuals: + convarm_layer = last_tensor + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + nb_feats_in = lvl_first_tensor.get_shape()[-1] + nb_feats_out = convarm_layer.get_shape()[-1] + add_layer = lvl_first_tensor + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = '%s_expand_down_merge_%d' % (prefix, level) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor) + add_layer = last_tensor + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + + name = '%s_res_down_merge_%d' % (prefix, level) + last_tensor = KL.add([add_layer, convarm_layer], name=name) + + name = '%s_res_down_merge_act_%d' % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_down_%d' % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # max pool if we're not at the last level + if level < (nb_levels - 1): + name = '%s_maxpool_%d' % (prefix, level) + last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +def conv_dec(nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + use_skip_connections=False, + skip_n_concatenations=0, + padding='same', + dilation_rate_mult=1, + activation='elu', + use_residuals=False, + final_pred_activation='softmax', + nb_conv_per_level=2, + layer_nb_feats=None, + batch_norm=None, + conv_dropout=0, + input_model=None): + """Fully Convolutional Decoder""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # if using skip connections, make sure need to use them. + if use_skip_connections: + assert input_model is not None, "is using skip connections, tensors dictionary is required" + + # first layer: input + input_name = '%s_input' % prefix + if input_model is None: + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + + # vol size info + ndims = len(input_shape) - 1 + if isinstance(pool_size, int): + if ndims > 1: + pool_size = (pool_size,) * ndims + + # prepare layers + convL = getattr(KL, 'Conv%dD' % ndims) + conv_kwargs = {'padding': padding, 'activation': activation} + upsample = getattr(KL, 'UpSampling%dD' % ndims) + + # up arm: + # nb_levels - 1 layers of Deconvolution3D + # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu + lfidx = 0 + for level in range(nb_levels - 1): + nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int) + conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level) + + # upsample matching the max pooling layers size + name = '%s_up_%d' % (prefix, nb_levels + level) + last_tensor = upsample(size=pool_size, name=name)(last_tensor) + up_tensor = last_tensor + + # merge layers combining previous layer + if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)): + conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1) + cat_tensor = input_model.get_layer(conv_name).output + name = '%s_merge_%d' % (prefix, nb_levels + level) + last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name) + + # convolution layers + for conv in range(nb_conv_per_level): + if layer_nb_feats is not None: + nb_lvl_feats = layer_nb_feats[lfidx] + lfidx += 1 + + name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv) + if conv < (nb_conv_per_level - 1) or (not use_residuals): + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + else: + last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + + if conv_dropout > 0: + name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv) + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + + # residual block + if use_residuals: + + # the "add" layer is the original input + # However, it may not have the right number of features to be added + add_layer = up_tensor + nb_feats_in = add_layer.get_shape()[-1] + nb_feats_out = last_tensor.get_shape()[-1] + if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): + name = '%s_expand_up_merge_%d' % (prefix, level) + add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer) + + if conv_dropout > 0: + noise_shape = [None, *[1] * ndims, nb_lvl_feats] + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + + name = '%s_res_up_merge_%d' % (prefix, level) + last_tensor = KL.add([last_tensor, add_layer], name=name) + + name = '%s_res_up_merge_act_%d' % (prefix, level) + last_tensor = KL.Activation(activation, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_up_%d' % (prefix, level) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # Compute likelihood prediction (no activation yet) + name = '%s_likelihood' % prefix + last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) + like_tensor = last_tensor + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + if final_pred_activation == 'softmax': + name = '%s_prediction' % prefix + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor) + + # otherwise create a layer that does nothing. + else: + name = '%s_prediction' % prefix + pred_tensor = KL.Activation('linear', name=name)(like_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) + return model + + +def add_prior(input_model, + prior_shape, + name='prior_model', + prefix=None, + use_logp=True, + final_pred_activation='softmax'): + """ + Append post-prior layer to a given model + """ + + # naming + model_name = name + if prefix is None: + prefix = model_name + + # prior input layer + prior_input_name = '%s-input' % prefix + prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name) + prior_tensor_input = prior_tensor + like_tensor = input_model.output + + # operation varies depending on whether we log() prior or not. + if use_logp: + print("Breaking change: use_logp option now requires log input!", file=sys.stderr) + merge_op = KL.add + + else: + # using sigmoid to get the likelihood values between 0 and 1 + # note: they won't add up to 1. + name = '%s_likelihood_sigmoid' % prefix + like_tensor = KL.Activation('sigmoid', name=name)(like_tensor) + merge_op = KL.multiply + + # merge the likelihood and prior layers into posterior layer + name = '%s_posterior' % prefix + post_tensor = merge_op([prior_tensor, like_tensor], name=name) + + # output prediction layer + # we use a softmax to compute P(L_x|I) where x is each location + pred_name = '%s_prediction' % prefix + if final_pred_activation == 'softmax': + assert use_logp, 'cannot do softmax when adding prior via P()' + print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) + softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1) + pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor) + + else: + pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor) + + # create the model + model_inputs = [*input_model.inputs, prior_tensor_input] + model = Model(inputs=model_inputs, outputs=[pred_tensor], name=model_name) + + # compile + return model + + +def single_ae(enc_size, + input_shape, + name='single_ae', + prefix=None, + ae_type='dense', # 'dense', or 'conv' + conv_size=None, + input_model=None, + enc_lambda_layers=None, + batch_norm=True, + padding='same', + activation=None, + include_mu_shift_layer=False, + do_vae=False): + """single-layer Autoencoder (i.e. input - encoding - output""" + + # naming + model_name = name + if prefix is None: + prefix = model_name + + if enc_lambda_layers is None: + enc_lambda_layers = [] + + # prepare input + input_name = '%s_input' % prefix + if input_model is None: + assert input_shape is not None, 'input_shape of input_model is necessary' + input_tensor = KL.Input(shape=input_shape, name=input_name) + last_tensor = input_tensor + else: + input_tensor = input_model.input + last_tensor = input_model.output + input_shape = last_tensor.shape.as_list()[1:] + input_nb_feats = last_tensor.shape.as_list()[-1] + + # prepare conv type based on input + ndims = len(input_shape) - 1 + if ae_type == 'conv': + convL = getattr(KL, 'Conv%dD' % ndims) + assert conv_size is not None, 'with conv ae, need conv_size' + conv_kwargs = {'padding': padding, 'activation': activation} + enc_size_str = None + + # if want to go through a dense layer in the middle of the U, need to: + # - flatten last layer if not flat + # - do dense encoding and decoding + # - unflatten (reshape spatially) at end + else: # ae_type == 'dense' + if len(input_shape) > 1: + name = '%s_ae_%s_down_flat' % (prefix, ae_type) + last_tensor = KL.Flatten(name=name)(last_tensor) + convL = conv_kwargs = None + assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer" + enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1] + + # recall this layer + pre_enc_layer = last_tensor + + # encoding layer + if ae_type == 'dense': + name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: # convolution + + # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats] + assert len(enc_size) == len(input_shape), \ + "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape)) + + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + + name = '%s_ae_mu_enc_conv' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + name = '%s_ae_mu_enc' % prefix + zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = '%s_ae_mu_enc' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = '%s_ae_mu_enc' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + if include_mu_shift_layer: + # shift + name = '%s_ae_mu_shift' % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = '%s_ae_mu_%s' % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_ae_mu_bn' % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = '%s_ae_mu' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + # if doing variational AE, will need the sigma layer as well. + if do_vae: + mu_tensor = last_tensor + + # encoding layer + if ae_type == 'dense': + name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str) + last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) + + else: + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + + assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..." + name = '%s_ae_sigma_enc_conv' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + name = '%s_ae_sigma_enc' % prefix + resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) + last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) + + elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck + name = '%s_ae_sigma_enc' % prefix + last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer) + # cannot use lambda, then mu and sigma will be same layer. + # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) + + else: + name = '%s_ae_sigma_enc' % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + + # encoding clean-up layers + for layer_fcn in enc_lambda_layers: + lambda_name = layer_fcn.__name__ + name = '%s_ae_sigma_%s' % (prefix, lambda_name) + last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) + + if batch_norm is not None: + name = '%s_ae_sigma_bn' % prefix + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # have a simple layer that does nothing to have a clear name before sampling + name = '%s_ae_sigma' % prefix + last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) + + logvar_tensor = last_tensor + + # VAE sampling + sampler = _VAESample().sample_z + + name = '%s_ae_sample' % prefix + last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor]) + + if include_mu_shift_layer: + # shift + name = '%s_ae_sample_shift' % prefix + last_tensor = layers.LocalBias(name=name)(last_tensor) + + # decoding layer + if ae_type == 'dense': + name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str) + last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor) + + # unflatten if dense method + if len(input_shape) > 1: + name = '%s_ae_%s_dec' % (prefix, ae_type) + last_tensor = KL.Reshape(input_shape, name=name)(last_tensor) + + else: + + if list(enc_size)[:-1] != list(input_shape)[:-1] and \ + all([f is not None for f in input_shape[:-1]]) and \ + all([f is not None for f in enc_size[:-1]]): + name = '%s_ae_mu_dec' % prefix + zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)] + last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) + + name = '%s_ae_%s_dec' % (prefix, ae_type) + last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor) + + if batch_norm is not None: + name = '%s_bn_ae_%s_dec' % (prefix, ae_type) + last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) + + # create the model and return + model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) + return model + + +############################################################################### +# Helper function +############################################################################### + +class _VAESample: + def __init__(self): + pass + + def sample_z(self, args): + mu, log_var = args + shape = K.shape(mu) + eps = K.random_normal(shape=shape, mean=0., stddev=1.) + return mu + K.exp(log_var / 2) * eps diff --git a/nobrainer/ext/neuron/utils.py b/nobrainer/ext/neuron/utils.py new file mode 100644 index 00000000..1162b51c --- /dev/null +++ b/nobrainer/ext/neuron/utils.py @@ -0,0 +1,548 @@ +""" +tensorflow/keras utilities for the neuron project + +If you use this code, please cite +Dalca AV, Guttag J, Sabuncu MR +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +CVPR 2018 + +or for the transformation/interpolation related functions: + +Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration +Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu +MICCAI 2018. + +Contact: adalca [at] csail [dot] mit [dot] edu +License: GPLv3 +""" + +import itertools +import numpy as np +import tensorflow as tf +import keras.backend as K + + +def interpn(vol, loc, interp_method='linear'): + """ + N-D gridded interpolation in tensorflow + + vol can have more dimensions than loc[i], in which case loc[i] acts as a slice + for the first dimensions + + Parameters: + vol: volume with size vol_shape or [*vol_shape, nb_features] + loc: an N-long list of N-D Tensors (the interpolation locations) for the new grid + each tensor has to have the same size (but not nec. same size as vol) + or a tensor of size [*new_vol_shape, D] + interp_method: interpolation type 'linear' (default) or 'nearest' + + Returns: + new interpolated volume of the same size as the entries in loc + """ + + if isinstance(loc, (list, tuple)): + loc = tf.stack(loc, -1) + nb_dims = loc.shape[-1] + + if len(vol.shape) not in [nb_dims, nb_dims + 1]: + raise Exception("Number of loc Tensors %d does not match volume dimension %d" + % (nb_dims, len(vol.shape[:-1]))) + + if nb_dims > len(vol.shape): + raise Exception("Loc dimension %d does not match volume dimension %d" + % (nb_dims, len(vol.shape))) + + if len(vol.shape) == nb_dims: + vol = K.expand_dims(vol, -1) + + # flatten and float location Tensors + loc = tf.cast(loc, 'float32') + + if isinstance(vol.shape, tf.TensorShape): + volshape = vol.shape.as_list() + else: + volshape = vol.shape + + # interpolate + if interp_method == 'linear': + loc0 = tf.floor(loc) + + # clip values + max_loc = [d - 1 for d in vol.get_shape().as_list()] + clipped_loc = [tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims)] + loc0lst = [tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims)] + + # get other end of point cube + loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)] + locs = [[tf.cast(f, 'int32') for f in loc0lst], [tf.cast(f, 'int32') for f in loc1]] + + # compute the difference between the upper value and the original value + # differences are basically 1 - (pt - floor(pt)) + # because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt)) + diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)] + diff_loc0 = [1 - d for d in diff_loc1] + weights_loc = [diff_loc1, diff_loc0] # note reverse ordering since weights are inverse of diff. + + # go through all the cube corners, indexed by a ND binary vector + # e.g. [0, 0] means this "first" corner in a 2-D "cube" + cube_pts = list(itertools.product([0, 1], repeat=nb_dims)) + interp_vol = 0 + + for c in cube_pts: + # get nd values + # note re: indices above volumes via https://github.com/tensorflow/tensorflow/issues/15091 + # It works on GPU because we do not perform index validation checking on GPU -- it's too + # expensive. Instead we fill the output with zero for the corresponding value. The CPU + # version caught the bad index and returned the appropriate error. + subs = [locs[c[d]][d] for d in range(nb_dims)] + + idx = sub2ind(vol.shape[:-1], subs) + vol_val = tf.gather(tf.reshape(vol, [-1, volshape[-1]]), idx) + + # get the weight of this cube_pt based on the distance + # if c[d] is 0 --> want weight = 1 - (pt - floor[pt]) = diff_loc1 + # if c[d] is 1 --> want weight = pt - floor[pt] = diff_loc0 + wts_lst = [weights_loc[c[d]][d] for d in range(nb_dims)] + wt = prod_n(wts_lst) + wt = K.expand_dims(wt, -1) + + # compute final weighted value for each cube corner + interp_vol += wt * vol_val + + else: + assert interp_method == 'nearest' + roundloc = tf.cast(tf.round(loc), 'int32') + + # clip values + max_loc = [tf.cast(d - 1, 'int32') for d in vol.shape] + roundloc = [tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims)] + + # get values + idx = sub2ind(vol.shape[:-1], roundloc) + interp_vol = tf.gather(tf.reshape(vol, [-1, vol.shape[-1]]), idx) + + return interp_vol + + +def resize(vol, zoom_factor, new_shape, interp_method='linear'): + """ + if zoom_factor is a list, it will determine the ndims, in which case vol has to be of length ndims or ndims + 1 + + if zoom_factor is an integer, then vol must be of length ndims + 1 + + new_shape should be a list of length ndims + + """ + + if isinstance(zoom_factor, (list, tuple)): + ndims = len(zoom_factor) + vol_shape = vol.shape[:ndims] + assert len(vol_shape) in (ndims, ndims + 1), \ + "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims) + else: + vol_shape = vol.shape[:-1] + ndims = len(vol_shape) + zoom_factor = [zoom_factor] * ndims + + # get grid for new shape + grid = volshape_to_ndgrid(new_shape) + grid = [tf.cast(f, 'float32') for f in grid] + offset = [grid[f] / zoom_factor[f] - grid[f] for f in range(ndims)] + offset = tf.stack(offset, ndims) + + # transform + return transform(vol, offset, interp_method) + + +zoom = resize + + +def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'): + """ + transform an affine matrix to a dense location shift tensor in tensorflow + + Algorithm: + - get grid and shift grid to be centered at the center of the image (optionally) + - apply affine matrix to each index. + - subtract grid + + Parameters: + affine_matrix: ND+1 x ND+1 or ND x ND+1 matrix (Tensor) + volshape: 1xN Nd Tensor of the size of the volume. + shift_center (optional) + indexing + + Returns: + shift field (Tensor) of size *volshape x N + """ + + if isinstance(volshape, tf.TensorShape): + volshape = volshape.as_list() + + if affine_matrix.dtype != 'float32': + affine_matrix = tf.cast(affine_matrix, 'float32') + + nb_dims = len(volshape) + + if len(affine_matrix.shape) == 1: + if len(affine_matrix) != (nb_dims * (nb_dims + 1)): + raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).' + 'Got len %d' % len(affine_matrix)) + + affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1]) + + if not (affine_matrix.shape[0] in [nb_dims, nb_dims + 1] and affine_matrix.shape[1] == (nb_dims + 1)): + raise Exception('Affine matrix shape should match' + '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + + '%d x %d+1.' % (nb_dims, nb_dims) + + 'Got: ' + str(volshape)) + + # list of volume ndgrid + # N-long list, each entry of shape volshape + mesh = volshape_to_meshgrid(volshape, indexing=indexing) + mesh = [tf.cast(f, 'float32') for f in mesh] + + if shift_center: + mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] + + # add an all-ones entry and transform into a large matrix + flat_mesh = [flatten(f) for f in mesh] + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32')) + mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # 4 x nb_voxels + + # compute locations + loc_matrix = tf.matmul(affine_matrix, mesh_matrix) # N+1 x nb_voxels + loc_matrix = tf.transpose(loc_matrix[:nb_dims, :]) # nb_voxels x N + loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims]) # *volshape x N + + # get shifts and return + return loc - tf.stack(mesh, axis=nb_dims) + + +def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=True, indexing='ij'): + """ + transform an affine matrix to a dense location shift tensor in tensorflow + + Algorithm: + - get grid and shift grid to be centered at the center of the image (optionally) + - apply affine matrix to each index. + - subtract grid + + Parameters: + transform_list: list of non-linear tensor (size of volshape) and affine ND+1 x ND+1 or ND x ND+1 tensor + volshape: 1xN Nd Tensor of the size of the volume. + shift_center (optional) + indexing + + Returns: + shift field (Tensor) of size *volshape x N + """ + + if isinstance(volshape, tf.TensorShape): + volshape = volshape.as_list() + + # convert transforms to floats + for i in range(len(transform_list)): + if transform_list[i].dtype != 'float32': + transform_list[i] = tf.cast(transform_list[i], 'float32') + + nb_dims = len(volshape) + + # transform affine to matrix if given as vector + if len(transform_list[1].shape) == 1: + if len(transform_list[1]) != (nb_dims * (nb_dims + 1)): + raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).' + 'Got len %d' % len(transform_list[1])) + + transform_list[1] = tf.reshape(transform_list[1], [nb_dims, nb_dims + 1]) + + if not (transform_list[1].shape[0] in [nb_dims, nb_dims + 1] and transform_list[1].shape[1] == (nb_dims + 1)): + raise Exception('Affine matrix shape should match' + '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + + '%d x %d+1.' % (nb_dims, nb_dims) + + 'Got: ' + str(volshape)) + + # list of volume ndgrid + # N-long list, each entry of shape volshape + mesh = volshape_to_meshgrid(volshape, indexing=indexing) + mesh = [tf.cast(f, 'float32') for f in mesh] + + if shift_center: + mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] + + # add an all-ones entry and transform into a large matrix + # non_linear_mesh = tf.unstack(transform_list[0], axis=3) + non_linear_mesh = tf.unstack(transform_list[0], axis=-1) + flat_mesh = [flatten(mesh[i]+non_linear_mesh[i]) for i in range(len(mesh))] + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32')) + mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # N+1 x nb_voxels + + # compute locations + loc_matrix = tf.matmul(transform_list[1], mesh_matrix) # N+1 x nb_voxels + loc_matrix = tf.transpose(loc_matrix[:nb_dims, :]) # nb_voxels x N + loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims]) # *volshape x N + + # get shifts and return + return loc - tf.stack(mesh, axis=nb_dims) + + +def transform(vol, loc_shift, interp_method='linear', indexing='ij'): + """ + transform interpolation N-D volumes (features) given shifts at each location in tensorflow + + Essentially interpolates volume vol at locations determined by loc_shift. + This is a spatial transform in the sense that at location [x] we now have the data from, + [x + shift] so we've moved data. + + Parameters: + vol: volume with size vol_shape or [*vol_shape, nb_features] + loc_shift: shift volume [*new_vol_shape, N] + interp_method (default:'linear'): 'linear', 'nearest' + indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian). + In general, prefer to leave this 'ij' + + Return: + new interpolated volumes in the same size as loc_shift[0] + """ + + # parse shapes + if isinstance(loc_shift.shape, tf.TensorShape): + volshape = loc_shift.shape[:-1].as_list() + else: + volshape = loc_shift.shape[:-1] + nb_dims = len(volshape) + + # location should be meshed and delta + mesh = volshape_to_meshgrid(volshape, indexing=indexing) # volume mesh + loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)] + + # test single + return interpn(vol, loc, interp_method=interp_method) + + +def integrate_vec(vec, time_dep=False, method='ss', **kwargs): + """ + Integrate (stationary of time-dependent) vector field (N-D Tensor) in tensorflow + + Aside from directly using tensorflow's numerical integration odeint(), also implements + "scaling and squaring", and quadrature. Note that the diff. equation given to odeint + is the one used in quadrature. + + Parameters: + vec: the Tensor field to integrate. + If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size), + then vector shape (vec_shape) should be + [vol_size, vol_ndim] (if stationary) + [vol_size, vol_ndim, nb_time_steps] (if time dependent) + time_dep: bool whether vector is time dependent + method: 'scaling_and_squaring' or 'ss' or 'quadrature' + + if using 'scaling_and_squaring': currently only supports integrating to time point 1. + nb_steps int number of steps. Note that this means the vec field gets broken own to 2**nb_steps. + so nb_steps of 0 means integral = vec. + + Returns: + int_vec: integral of vector field with same shape as the input + """ + + if method not in ['ss', 'scaling_and_squaring', 'ode', 'quadrature']: + raise ValueError("method has to be 'scaling_and_squaring' or 'ode'. found: %s" % method) + + if method in ['ss', 'scaling_and_squaring']: + nb_steps = kwargs['nb_steps'] + assert nb_steps >= 0, 'nb_steps should be >= 0, found: %d' % nb_steps + + if time_dep: + svec = K.permute_dimensions(vec, [-1, *range(0, vec.shape[-1] - 1)]) + assert 2 ** nb_steps == svec.shape[0], "2**nb_steps and vector shape don't match" + + svec = svec / (2 ** nb_steps) + for _ in range(nb_steps): + svec = svec[0::2] + tf.map_fn(transform, svec[1::2, :], svec[0::2, :]) + + disp = svec[0, :] + + else: + vec = vec / (2 ** nb_steps) + for _ in range(nb_steps): + vec += transform(vec, vec) + disp = vec + + else: # method == 'quadrature': + nb_steps = kwargs['nb_steps'] + assert nb_steps >= 1, 'nb_steps should be >= 1, found: %d' % nb_steps + + vec = vec / nb_steps + + if time_dep: + disp = vec[..., 0] + for si in range(nb_steps - 1): + disp += transform(vec[..., si + 1], disp) + else: + disp = vec + for _ in range(nb_steps - 1): + disp += transform(vec, disp) + + return disp + + +def volshape_to_ndgrid(volshape, **kwargs): + """ + compute Tensor ndgrid from a volume size + + Parameters: + volshape: the volume size + + Returns: + A list of Tensors + + See Also: + ndgrid + """ + + isint = [float(d).is_integer() for d in volshape] + if not all(isint): + raise ValueError("volshape needs to be a list of integers") + + linvec = [tf.range(0, d) for d in volshape] + return ndgrid(*linvec, **kwargs) + + +def volshape_to_meshgrid(volshape, **kwargs): + """ + compute Tensor meshgrid from a volume size + + Parameters: + volshape: the volume size + + Returns: + A list of Tensors + + See Also: + tf.meshgrid, meshgrid, ndgrid, volshape_to_ndgrid + """ + + isint = [float(d).is_integer() for d in volshape] + if not all(isint): + raise ValueError("volshape needs to be a list of integers") + + linvec = [tf.range(0, d) for d in volshape] + return meshgrid(*linvec, **kwargs) + + +def ndgrid(*args, **kwargs): + """ + broadcast Tensors on an N-D grid with ij indexing + uses meshgrid with ij indexing + + Parameters: + *args: Tensors with rank 1 + **args: "name" (optional) + + Returns: + A list of Tensors + + """ + return meshgrid(*args, indexing='ij', **kwargs) + + +def meshgrid(*args, **kwargs): + """ + + meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically + improves runtime by changing the last step to tiling instead of multiplication. + https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921 + + Broadcasts parameters for evaluation on an N-D grid. + Given N one-dimensional coordinate arrays `*args`, returns a list `outputs` + of N-D coordinate arrays for evaluating expressions on an N-D grid. + Notes: + `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions. + When the `indexing` argument is set to 'xy' (the default), the broadcasting + instructions for the first two dimensions are swapped. + Examples: + Calling `X, Y = meshgrid(x, y)` with the tensors + ```python + x = [1, 2, 3] + y = [4, 5, 6] + X, Y = meshgrid(x, y) + # X = [[1, 2, 3], + # [1, 2, 3], + # [1, 2, 3]] + # Y = [[4, 4, 4], + # [5, 5, 5], + # [6, 6, 6]] + ``` + Args: + *args: `Tensor`s with rank 1. + **kwargs: + - indexing: Either 'xy' or 'ij' (optional, default: 'xy'). + - name: A name for the operation (optional). + Returns: + outputs: A list of N `Tensor`s with rank N. + Raises: + TypeError: When no keyword arguments (kwargs) are passed. + ValueError: When indexing keyword argument is not one of `xy` or `ij`. + """ + + indexing = kwargs.pop("indexing", "xy") + if kwargs: + key = list(kwargs.keys())[0] + raise TypeError("'{}' is an invalid keyword argument " + "for this function".format(key)) + + if indexing not in ("xy", "ij"): + raise ValueError("indexing parameter must be either 'xy' or 'ij'") + + # with ops.name_scope(name, "meshgrid", args) as name: + ndim = len(args) + s0 = (1,) * ndim + + # Prepare reshape by inserting dimensions with size 1 where needed + output = [] + for i, x in enumerate(args): + output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::]))) + # Create parameters for broadcasting each tensor to the full size + shapes = [tf.size(x) for x in args] + sz = [x.get_shape().as_list()[0] for x in args] + + # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype + if indexing == "xy" and ndim > 1: + output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2)) + output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2)) + shapes[0], shapes[1] = shapes[1], shapes[0] + sz[0], sz[1] = sz[1], sz[0] + + for i in range(len(output)): + stack_sz = [*sz[:i], 1, *sz[(i + 1):]] + if indexing == 'xy' and ndim > 1 and i < 2: + stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0] + output[i] = tf.tile(output[i], tf.stack(stack_sz)) + return output + + +def flatten(v): + """flatten Tensor v""" + + return tf.reshape(v, [-1]) + + +def prod_n(lst): + prod = lst[0] + for p in lst[1:]: + prod *= p + return prod + + +def sub2ind(siz, subs): + """assumes column-order major""" + # subs is a list + assert len(siz) == len(subs), 'found inconsistent siz and subs: %d %d' % (len(siz), len(subs)) + + k = np.cumprod(siz[::-1]) + + ndx = subs[-1] + for i, v in enumerate(subs[:-1][::-1]): + ndx = ndx + v * k[i] + + return ndx diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py index 68a797d3..27d83502 100644 --- a/nobrainer/models/__init__.py +++ b/nobrainer/models/__init__.py @@ -12,6 +12,7 @@ from .progressivegan import progressivegan from .unet import unet from .unetr import unetr +from .labels_to_image_model import labels_to_image_model __all__ = ["get", "list_available_models"] @@ -27,7 +28,8 @@ "attention_unet_with_inception": attention_unet_with_inception, "unetr": unetr, "variational_meshnet": variational_meshnet, - "bayesian_vnet": bayesian_vnet + "bayesian_vnet": bayesian_vnet, + "synthgenerator": labels_to_image_model } diff --git a/nobrainer/models/labels_to_image_model.py b/nobrainer/models/labels_to_image_model.py new file mode 100644 index 00000000..0083b2d5 --- /dev/null +++ b/nobrainer/models/labels_to_image_model.py @@ -0,0 +1,297 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np +import tensorflow as tf +import keras.layers as KL +from keras.models import Model + +# third-party imports +from nobrainer.ext.lab2im import utils +from nobrainer.ext.lab2im import layers +from nobrainer.ext.lab2im import edit_tensors as l2i_et +from nobrainer.ext.lab2im.edit_volumes import get_ras_axes + + +def labels_to_image_model(labels_shape, + n_channels, + generation_labels, + output_labels, + n_neutral_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + flipping=True, + aff=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3., + nonlin_scale=.0625, + randomise_res=False, + max_res_iso=4., + max_res_aniso=8., + data_res=None, + thickness=None, + bias_field_std=.5, + bias_scale=.025, + return_gradients=False): + """ + This function builds a keras/tensorflow model to generate images from provided label maps. + The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. + The model will take as inputs: + -a label map + -a vector containing the means of the Gaussian Mixture Model for each label, + -a vector containing the standard deviations of the Gaussian Mixture Model for each label, + The model returns: + -the generated image normalised between 0 and 1. + -the corresponding label map, with only the labels present in output_labels (the other are reset to zero). + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + :param labels_shape: shape of the input label maps. Can be a sequence or a 1d numpy array. + :param n_channels: number of channels to be synthesised. + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array. It should be organised as follows: background label first, then + non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same hemisphere (can be left or right), + and finally all the corresponding contralateral structures (in the same order). + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in the + label maps returned by this model, i.e. all occurrences of generation_labels[i] in the input label maps will be + converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param n_neutral_labels: number of non-sided generation labels. + :param atlas_res: resolution of the input label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param target_res: target resolution of the generated images and corresponding label maps. + Can be a number (isotropic resolution), a sequence, or a 1d numpy array. + :param output_shape: (optional) desired shape of the output image, obtained by randomly cropping the generated image + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + Default is None, where no cropping is performed. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites output_shape + if necessary. Can be an integer (same size in all dimensions), a sequence, or a 1d numpy array. + :param flipping: (optional) whether to introduce right/left random flipping + :param aff: (optional) example of an (n_dims+1)x(n_dims+1) affine matrix of one of the input label map. + Used to find brain's right/left axis. Should be given if flipping is True. + :param scaling_bounds: (optional) range of the random scaling to apply at each mini-batch. The scaling factor for + each dimension is sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.2 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we + sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the elastic + deformation off. + :param nonlin_scale: (optional) if nonlin_std is strictly positive, factor between the shapes of the input + label maps and the shape of the input non-linear tensor. + :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and + 2) resampled to high resolution. The low resolution is uniformly resampled at each minibatch from [1mm, 9mm]. + In that process, the images generated by sampling the GMM are 1) blurred at the sampled LR, 2) downsampled at LR, + and 3) resampled at target_resolution. + :param max_res_iso: (optional) If randomise_res is True, this enables to control the upper bound of the uniform + distribution from which we sample the random resolution U(min_res, max_res_iso), where min_res is the resolution of + the input label maps. Must be a number, and default is 4. Set to None to deactivate it, but if randomise_res is + True, at least one of max_res_iso or max_res_aniso must be given. + :param max_res_aniso: If randomise_res is True, this enables to downsample the input volumes to a random LR in + only 1 (random) direction. This is done by randomly selecting a direction i in the range [0, n_dims-1], and sampling + a value in the corresponding uniform distribution U(min_res[i], max_res_aniso[i]), where min_res is the resolution + of the input label maps. Can be a number, a sequence, or a 1d numpy array. Set to None to deactivate it, but if + randomise_res is True, at least one of max_res_iso or max_res_aniso must be given. + :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled when + randomise_res is True. This triggers a blurring to mimic the specified acquisition resolution, but the downsampling + is optional (see param downsample). Default for data_res is None, where images are slightly blurred. + If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, a 1d + numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array (or a + path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel. + :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low + resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res. + :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with a + bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full size, + and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the std dev of + the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field corruption. + :param bias_scale: (optional) If bias_field_std is strictly positive, this designates the ratio between the + size of the input label maps and the size of the first sampled tensor for synthesising the bias field. + :param return_gradients: (optional) whether to return the synthetic image or the magnitude of its spatial gradient + (computed with Sobel kernels). + """ + + # reformat resolutions + labels_shape = utils.reformat_to_list(labels_shape) + n_dims, _ = utils.get_dims(labels_shape) + atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims, n_channels) + data_res = atlas_res if data_res is None else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) + thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + atlas_res = atlas_res[0] + target_res = atlas_res if target_res is None else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + + # get shapes + crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + + # define model inputs + labels_input = KL.Input(shape=labels_shape + [1], name='labels_input', dtype='int32') + means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') + stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='std_devs_input') + list_inputs = [labels_input, means_input, stds_input] + + # deform labels + labels = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method='nearest')(labels_input) + + # cropping + if crop_shape != labels_shape: + labels = layers.RandomCrop(crop_shape)(labels) + + # flipping + if flipping: + assert aff is not None, 'aff should not be None if flipping is True' + labels = layers.RandomFlip(get_ras_axes(aff, n_dims)[0], True, generation_labels, n_neutral_labels)(labels) + + # build synthetic image + image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + + # apply bias field + if bias_field_std > 0: + image = layers.BiasFieldCorruption(bias_field_std, bias_scale, False)(image) + + # intensity augmentation + image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.5, separate_channels=True)(image) + + # loop over channels + channels = list() + split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) if (n_channels > 1) else [image] + for i, channel in enumerate(split): + + if randomise_res: + max_res_iso = np.array(utils.reformat_to_list(max_res_iso, length=n_dims, dtype='float')) + max_res_aniso = np.array(utils.reformat_to_list(max_res_aniso, length=n_dims, dtype='float')) + max_res = np.maximum(max_res_iso, max_res_aniso) + resolution, blur_res = layers.SampleResolution(atlas_res, max_res_iso, max_res_aniso)(means_input) + sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, resolution, thickness=blur_res) + channel = layers.DynamicGaussianBlur(0.75 * max_res / np.array(atlas_res), 1.03)([channel, sigma]) + channel = layers.MimicAcquisition(atlas_res, atlas_res, output_shape, False)([channel, resolution]) + channels.append(channel) + + else: + sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, data_res[i], thickness=thickness[i]) + channel = layers.GaussianBlur(sigma, 1.03)(channel) + resolution = KL.Lambda(lambda x: tf.convert_to_tensor(data_res[i], dtype='float32'))([]) + channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)([channel, resolution]) + channels.append(channel) + + # concatenate all channels back + image = KL.Lambda(lambda x: tf.concat(x, -1))(channels) if len(channels) > 1 else channels[0] + + # compute image gradient + if return_gradients: + image = layers.ImageGradients('sobel', True, name='image_gradients')(image) + image = layers.IntensityAugmentation(clip=10, normalise=True)(image) + + # resample labels at target resolution + if crop_shape != output_shape: + labels = l2i_et.resample_tensor(labels, output_shape, interp_method='nearest') + + # map generation labels to segmentation values + labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + + # build model (dummy layer enables to keep the labels when plugging this model to other models) + image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) + brain_model = Model(inputs=list_inputs, outputs=[image, labels]) + + return brain_model + + +def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n): + + # reformat resolutions to lists + atlas_res = utils.reformat_to_list(atlas_res) + n_dims = len(atlas_res) + target_res = utils.reformat_to_list(target_res) + + # get resampling factor + if atlas_res != target_res: + resample_factor = [atlas_res[i] / float(target_res[i]) for i in range(n_dims)] + else: + resample_factor = None + + # output shape specified, need to get cropping shape, and resample shape if necessary + if output_shape is not None: + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + + # make sure that output shape is smaller or equal to label shape + if resample_factor is not None: + output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + else: + output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in output_shape] + if output_shape != tmp_shape: + print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, + tmp_shape)) + output_shape = tmp_shape + + # get cropping and resample shape + if resample_factor is not None: + cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + else: + cropping_shape = output_shape + + # no output shape specified, so no cropping unless label_shape is not divisible by output_div_by_n + else: + + # make sure output shape is divisible by output_div_by_n + if output_div_by_n is not None: + + # if resampling, get the potential output_shape and check if it is divisible by n + if resample_factor is not None: + output_shape = [int(labels_shape[i] * resample_factor[i]) for i in range(n_dims)] + output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in output_shape] + cropping_shape = [int(np.around(output_shape[i] / resample_factor[i], 0)) for i in range(n_dims)] + # if no resampling, simply check if image_shape is divisible by n + else: + cropping_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in labels_shape] + output_shape = cropping_shape + + # if no need to be divisible by n, simply take cropping_shape as image_shape, and build output_shape + else: + cropping_shape = labels_shape + if resample_factor is not None: + output_shape = [int(cropping_shape[i] * resample_factor[i]) for i in range(n_dims)] + else: + output_shape = cropping_shape + + return cropping_shape, output_shape diff --git a/nobrainer/processing/brain_generator.py b/nobrainer/processing/brain_generator.py new file mode 100644 index 00000000..2b7bfee9 --- /dev/null +++ b/nobrainer/processing/brain_generator.py @@ -0,0 +1,335 @@ +""" +If you use this code, please cite one of the SynthSeg papers: +https://github.com/BBillot/SynthSeg/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +# python imports +import numpy as np + +# project imports +from nobrainer.ext.SynthSeg.model_inputs import build_model_inputs +from nobrainer.models.labels_to_image_model import labels_to_image_model + +# third-party imports +from nobrainer.ext.lab2im import utils, edit_volumes + + +class BrainGenerator: + + def __init__(self, + labels_dir, + generation_labels=None, + n_neutral_labels=None, + output_labels=None, + subjects_prob=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + prior_distributions='uniform', + generation_classes=None, + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, + flipping=True, + scaling_bounds=.2, + rotation_bounds=15, + shearing_bounds=.012, + translation_bounds=False, + nonlin_std=4., + nonlin_scale=.04, + randomise_res=True, + max_res_iso=4., + max_res_aniso=8., + data_res=None, + thickness=None, + bias_field_std=.7, + bias_scale=.025, + return_gradients=False): + """ + This class is wrapper around the labels_to_image_model model. It contains the GPU model that generates images + from labels maps, and a python generator that supplies the input data for this model. + To generate pairs of image/labels you can just call the method generate_image() on an object of this class. + + :param labels_dir: path of folder with all input label maps, or to a single label map. + + # IMPORTANT !!! + # Each time we provide a parameter with separate values for each axis (e.g. with a numpy array or a sequence), + # these values refer to the RAS axes. + + # label maps-related parameters + :param generation_labels: (optional) list of all possible label values in the input label maps. + Default is None, where the label values are directly gotten from the provided label maps. + If not None, can be a sequence or a 1d numpy array, or the path to a 1d numpy array. + If flipping is true (i.e. right/left flipping is enabled), generation_labels should be organised as follows: + background label first, then non-sided labels (e.g. CSF, brainstem, etc.), then all the structures of the same + hemisphere (can be left or right), and finally all the corresponding contralateral structures in the same order. + :param n_neutral_labels: (optional) number of non-sided generation labels. This is important only if you use + flipping augmentation. Default is total number of label values. + :param output_labels: (optional) list of the same length as generation_labels to indicate which values to use in + the label maps returned by this function, i.e. all occurrences of generation_labels[i] in the input label maps + will be converted to output_labels[i] in the returned label maps. Examples: + Set output_labels[i] to zero if you wish to erase the value generation_labels[i] from the returned label maps. + Set output_labels[i]=generation_labels[i] to keep the value generation_labels[i] in the returned maps. + Can be a list or a 1d numpy array. By default output_labels is equal to generation_labels. + :param subjects_prob: (optional) relative order of importance (doesn't have to be probabilistic), with which to + pick the provided label maps at each minibatch. Can be a sequence, a 1D numpy array, or the path to such an + array, and it must be as long as path_label_maps. By default, all label maps are chosen with the same importance + + # output-related parameters + :param batchsize: (optional) numbers of images to generate per mini-batch. Default is 1. + :param n_channels: (optional) number of channels to be synthesised. Default is 1. + :param target_res: (optional) target resolution of the generated images and corresponding label maps. + If None, the outputs will have the same resolution as the input label maps. + Can be a number (isotropic resolution), a sequence, a 1d numpy array, or the path to a 1d numpy array. + :param output_shape: (optional) shape of the output image, obtained by randomly cropping the generated image. + Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or the path to a 1d numpy array. + Default is None, where no cropping is performed. + :param output_div_by_n: (optional) forces the output shape to be divisible by this value. It overwrites + output_shape if necessary. Can be an integer (same size in all dimensions), a sequence, a 1d numpy array, or + the path to a 1d numpy array. + + # GMM-sampling parameters + :param generation_classes: (optional) Indices regrouping generation labels into classes of same intensity + distribution. Regrouped labels will thus share the same Gaussian when sampling a new image. Can be a sequence, a + 1d numpy array, or the path to a 1d numpy array. It should have the same length as generation_labels, and + contain values between 0 and K-1, where K is the total number of classes. + Default is all labels have different classes (K=len(generation_labels)). + :param prior_distributions: (optional) type of distribution from which we sample the GMM parameters. + Can either be 'uniform', or 'normal'. Default is 'uniform'. + :param prior_means: (optional) hyperparameters controlling the prior distributions of the GMM means. Because + these prior distributions are uniform or normal, they require by 2 hyperparameters. Thus prior_means can be: + 1) a sequence of length 2, directly defining the two hyperparameters: [min, max] if prior_distributions is + uniform, [mean, std] if the distribution is normal. The GMM means of are independently sampled at each + mini_batch from the same distribution. + 2) an array of shape (2, K), where K is the number of classes (K=len(generation_labels) if generation_classes is + not given). The mean of the Gaussian distribution associated to class k in [0, ...K-1] is sampled at each + mini-batch from U(prior_means[0,k], prior_means[1,k]) if prior_distributions is uniform, and from + N(prior_means[0,k], prior_means[1,k]) if prior_distributions is normal. + 3) an array of shape (2*n_mod, K), where each block of two rows is associated to hyperparameters derived + from different modalities. In this case, if use_specific_stats_for_channel is False, we first randomly select a + modality from the n_mod possibilities, and we sample the GMM means like in 2). + If use_specific_stats_for_channel is True, each block of two rows correspond to a different channel + (n_mod=n_channels), thus we select the corresponding block to each channel rather than randomly drawing it. + 4) the path to such a numpy array. + Default is None, which corresponds to prior_means = [25, 225]. + :param prior_stds: (optional) same as prior_means but for the standard deviations of the GMM. + Default is None, which corresponds to prior_stds = [5, 25]. + :param use_specific_stats_for_channel: (optional) whether the i-th block of two rows in the prior arrays must be + only used to generate the i-th channel. If True, n_mod should be equal to n_channels. Default is False. + :param mix_prior_and_random: (optional) if prior_means is not None, enables to reset the priors to their default + values for half of these cases, and thus generate images of random contrast. + + # spatial deformation parameters + :param flipping: (optional) whether to introduce right/left random flipping. Default is True. + :param scaling_bounds: (optional) range of the random sampling to apply at each mini-batch. The scaling factor + for each dimension is sampled from a uniform distribution of predefined bounds. Can either be: + 1) a number, in which case the scaling factor is independently sampled from the uniform distribution of bounds + [1-scaling_bounds, 1+scaling_bounds] for each dimension. + 2) a sequence, in which case the scaling factor is sampled from the uniform distribution of bounds + (1-scaling_bounds[i], 1+scaling_bounds[i]) for the i-th dimension. + 3) a numpy array of shape (2, n_dims), in which case the scaling factor is sampled from the uniform distribution + of bounds (scaling_bounds[0, i], scaling_bounds[1, i]) for the i-th dimension. + 4) False, in which case scaling is completely turned off. + Default is scaling_bounds = 0.2 (case 1) + :param rotation_bounds: (optional) same as scaling bounds but for the rotation angle, except that for cases 1 + and 2, the bounds are centred on 0 rather than 1, i.e. [0+rotation_bounds[i], 0-rotation_bounds[i]]. + Default is rotation_bounds = 15. + :param shearing_bounds: (optional) same as scaling bounds. Default is shearing_bounds = 0.012. + :param translation_bounds: (optional) same as scaling bounds. Default is translation_bounds = False, but we + encourage using it when cropping is deactivated (i.e. when output_shape=None in BrainGenerator). + :param nonlin_std: (optional) Maximum value for the standard deviation of the normal distribution from which we + sample the first tensor for synthesising the deformation field. Set to 0 if you wish to completely turn the + elastic deformation off. + :param nonlin_scale: (optional) if nonlin_std is strictly positive, factor between the shapes of the + input label maps and the shape of the input non-linear tensor. + + # blurring/resampling parameters + :param randomise_res: (optional) whether to mimic images that would have been 1) acquired at low resolution, and + 2) resampled to high resolution. The low resolution is uniformly resampled at each minibatch from [1mm, 9mm]. + In that process, the images generated by sampling the GMM are: + 1) blurred at the sampled LR, 2) downsampled at LR, and 3) resampled at target_resolution. + :param max_res_iso: (optional) If randomise_res is True, this enables to control the upper bound of the uniform + distribution from which we sample the random resolution U(min_res, max_res_iso), where min_res is the resolution + of the input label maps. Must be a number, and default is 4. Set to None to deactivate it, but if randomise_res + is True, at least one of max_res_iso or max_res_aniso must be given. + :param max_res_aniso: If randomise_res is True, this enables to downsample the input volumes to a random LR + in only 1 (random) direction. This is done by randomly selecting a direction i in range [0, n_dims-1], and + sampling a value in the corresponding uniform distribution U(min_res[i], max_res_aniso[i]), where min_res is the + resolution of the input label maps. Can be a number, a sequence, or a 1d numpy array. Set to None to deactivate + it, but if randomise_res is True, at least one of max_res_iso or max_res_aniso must be given. + :param data_res: (optional) specific acquisition resolution to mimic, as opposed to random resolution sampled + when randomise_res is True. This triggers a blurring which mimics the acquisition resolution, but downsampling + is optional (see param downsample). Default for data_res is None, where images are slightly blurred. + If the generated images are uni-modal, data_res can be a number (isotropic acquisition resolution), a sequence, + a 1d numpy array, or the path to a 1d numpy array. In the multi-modal case, it should be given as a numpy array + (or a path) of size (n_mod, n_dims), where each row is the acquisition resolution of the corresponding channel. + :param thickness: (optional) if data_res is provided, we can further specify the slice thickness of the low + resolution images to mimic. Must be provided in the same format as data_res. Default thickness = data_res. + + # bias field parameters + :param bias_field_std: (optional) If strictly positive, this triggers the corruption of synthesised images with + a bias field. It is obtained by sampling a first small tensor from a normal distribution, resizing it to full + size, and rescaling it to positive values by taking the voxel-wise exponential. bias_field_std designates the + std dev of the normal distribution from which we sample the first tensor. Set to 0 to deactivate bias field. + :param bias_scale: (optional) If bias_field_std is strictly positive, this designates the ratio between + the size of the input label maps and the size of the first sampled tensor for synthesising the bias field. + + :param return_gradients: (optional) whether to return the synthetic image or the magnitude of its spatial + gradient (computed with Sobel kernels). + """ + + # prepare data files + self.labels_paths = utils.list_images_in_folder(labels_dir) + if subjects_prob is not None: + self.subjects_prob = np.array(utils.reformat_to_list(subjects_prob, load_as_numpy=True), dtype='float32') + assert len(self.subjects_prob) == len(self.labels_paths), \ + 'subjects_prob should have the same length as labels_path, ' \ + 'had {} and {}'.format(len(self.subjects_prob), len(self.labels_paths)) + else: + self.subjects_prob = None + + # generation parameters + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + self.n_channels = n_channels + if generation_labels is not None: + self.generation_labels = utils.load_array_if_path(generation_labels) + else: + self.generation_labels, _ = utils.get_list_labels(labels_dir=labels_dir) + if output_labels is not None: + self.output_labels = utils.load_array_if_path(output_labels) + else: + self.output_labels = self.generation_labels + if n_neutral_labels is not None: + self.n_neutral_labels = n_neutral_labels + else: + self.n_neutral_labels = self.generation_labels.shape[0] + self.target_res = utils.load_array_if_path(target_res) + self.batchsize = batchsize + # preliminary operations + self.flipping = flipping + self.output_shape = utils.load_array_if_path(output_shape) + self.output_div_by_n = output_div_by_n + # GMM parameters + self.prior_distributions = prior_distributions + if generation_classes is not None: + self.generation_classes = utils.load_array_if_path(generation_classes) + assert self.generation_classes.shape == self.generation_labels.shape, \ + 'if provided, generation_classes should have the same shape as generation_labels' + unique_classes = np.unique(self.generation_classes) + assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ + 'generation_classes should a linear range between 0 and its maximum value.' + else: + self.generation_classes = np.arange(self.generation_labels.shape[0]) + self.prior_means = utils.load_array_if_path(prior_means) + self.prior_stds = utils.load_array_if_path(prior_stds) + self.use_specific_stats_for_channel = use_specific_stats_for_channel + # linear transformation parameters + self.scaling_bounds = utils.load_array_if_path(scaling_bounds) + self.rotation_bounds = utils.load_array_if_path(rotation_bounds) + self.shearing_bounds = utils.load_array_if_path(shearing_bounds) + self.translation_bounds = utils.load_array_if_path(translation_bounds) + # elastic transformation parameters + self.nonlin_std = nonlin_std + self.nonlin_scale = nonlin_scale + # blurring parameters + self.randomise_res = randomise_res + self.max_res_iso = max_res_iso + self.max_res_aniso = max_res_aniso + self.data_res = utils.load_array_if_path(data_res) + assert not (self.randomise_res & (self.data_res is not None)), \ + 'randomise_res and data_res cannot be provided at the same time' + self.thickness = utils.load_array_if_path(thickness) + # bias field parameters + self.bias_field_std = bias_field_std + self.bias_scale = bias_scale + self.return_gradients = return_gradients + + # build transformation model + self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model() + + # build generator for model inputs + self.model_inputs_generator = self._build_model_inputs_generator(mix_prior_and_random) + + # build brain generator + self.brain_generator = self._build_brain_generator() + + def _build_labels_to_image_model(self): + # build_model + lab_to_im_model = labels_to_image_model(labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + n_neutral_labels=self.n_neutral_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + flipping=self.flipping, + aff=np.eye(4), + scaling_bounds=self.scaling_bounds, + rotation_bounds=self.rotation_bounds, + shearing_bounds=self.shearing_bounds, + translation_bounds=self.translation_bounds, + nonlin_std=self.nonlin_std, + nonlin_scale=self.nonlin_scale, + randomise_res=self.randomise_res, + max_res_iso=self.max_res_iso, + max_res_aniso=self.max_res_aniso, + data_res=self.data_res, + thickness=self.thickness, + bias_field_std=self.bias_field_std, + bias_scale=self.bias_scale, + return_gradients=self.return_gradients) + out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] + return lab_to_im_model, out_shape + + def _build_model_inputs_generator(self, mix_prior_and_random): + # build model's inputs generator + model_inputs_generator = build_model_inputs(path_label_maps=self.labels_paths, + n_labels=len(self.generation_labels), + batchsize=self.batchsize, + n_channels=self.n_channels, + subjects_prob=self.subjects_prob, + generation_classes=self.generation_classes, + prior_means=self.prior_means, + prior_stds=self.prior_stds, + prior_distributions=self.prior_distributions, + use_specific_stats_for_channel=self.use_specific_stats_for_channel, + mix_prior_and_random=mix_prior_and_random) + return model_inputs_generator + + def _build_brain_generator(self): + while True: + model_inputs = next(self.model_inputs_generator) + [image, labels] = self.labels_to_image_model.predict(model_inputs) + yield image, labels + + def generate_brain(self): + """call this method when an object of this class has been instantiated to generate new brains""" + (image, labels) = next(self.brain_generator) + # put back images in native space + list_images = list() + list_labels = list() + for i in range(self.batchsize): + list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), + aff_ref=self.aff, n_dims=self.n_dims)) + list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), + aff_ref=self.aff, n_dims=self.n_dims)) + image = np.squeeze(np.stack(list_images, axis=0)) + labels = np.squeeze(np.stack(list_labels, axis=0)) + return image, labels From 14d11ca4eeb69869d78960153241923792411585 Mon Sep 17 00:00:00 2001 From: Harsha Date: Wed, 19 Jun 2024 17:25:55 -0400 Subject: [PATCH 13/14] resolved https://github.com/neuronets/nobrainer/issues/339 --- nobrainer/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index f2d6e8bb..fe57d357 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -147,7 +147,7 @@ def from_tfrecords( label_mapping=label_mapping, num_parallel_calls=num_parallel_calls ) - ds_obj.filter_zero_volumes() + # ds_obj.filter_zero_volumes() # TODO automatically determine batch size ds_obj.batch(1) From af67744d8a678bee48a352a71591c4fe4aefbb32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 14 Jul 2024 22:50:52 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nobrainer/ext/SynthSeg/model_inputs.py | 95 +- nobrainer/ext/lab2im/__init__.py | 7 +- nobrainer/ext/lab2im/edit_tensors.py | 197 ++- nobrainer/ext/lab2im/edit_volumes.py | 1837 +++++++++++++++------ nobrainer/ext/lab2im/image_generator.py | 153 +- nobrainer/ext/lab2im/lab2im_model.py | 126 +- nobrainer/ext/lab2im/layers.py | 1239 ++++++++++---- nobrainer/ext/lab2im/utils.py | 864 +++++++--- nobrainer/ext/neuron/__init__.py | 4 +- nobrainer/ext/neuron/layers.py | 219 ++- nobrainer/ext/neuron/models.py | 753 +++++---- nobrainer/ext/neuron/utils.py | 207 ++- nobrainer/models/__init__.py | 4 +- nobrainer/models/labels_to_image_model.py | 232 ++- nobrainer/processing/brain_generator.py | 199 ++- nobrainer/processing/segmentation.py | 4 +- nobrainer/tfrecord.py | 8 +- 17 files changed, 4189 insertions(+), 1959 deletions(-) diff --git a/nobrainer/ext/SynthSeg/model_inputs.py b/nobrainer/ext/SynthSeg/model_inputs.py index 34f4b4ad..e6d9cda2 100644 --- a/nobrainer/ext/SynthSeg/model_inputs.py +++ b/nobrainer/ext/SynthSeg/model_inputs.py @@ -13,7 +13,6 @@ License. """ - # python imports import numpy as np import numpy.random as npr @@ -22,17 +21,19 @@ from nobrainer.ext.lab2im import utils -def build_model_inputs(path_label_maps, - n_labels, - batchsize=1, - n_channels=1, - subjects_prob=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - mix_prior_and_random=False): +def build_model_inputs( + path_label_maps, + n_labels, + batchsize=1, + n_channels=1, + subjects_prob=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, +): """ This function builds a generator that will be used to give the necessary inputs to the label_to_image model: the input label maps, as well as the means and stds defining the parameters of the GMM (which change at each minibatch). @@ -86,7 +87,9 @@ def build_model_inputs(path_label_maps, while True: # randomly pick as many images as batchsize - indices = npr.choice(np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob) + indices = npr.choice( + np.arange(len(path_label_maps)), size=batchsize, p=subjects_prob + ) # initialise input lists list_label_maps = [] @@ -96,8 +99,10 @@ def build_model_inputs(path_label_maps, for idx in indices: # load input label map - lab = utils.load_volume(path_label_maps[idx], dtype='int', aff_ref=np.eye(4)) - if (npr.uniform() > 0.7) & ('seg_cerebral' in path_label_maps[idx]): + lab = utils.load_volume( + path_label_maps[idx], dtype="int", aff_ref=np.eye(4) + ) + if (npr.uniform() > 0.7) & ("seg_cerebral" in path_label_maps[idx]): lab[lab == 24] = 0 # add label map to inputs @@ -112,42 +117,72 @@ def build_model_inputs(path_label_maps, if isinstance(prior_means, np.ndarray): if (prior_means.shape[0] > 2) & use_specific_stats_for_channel: if prior_means.shape[0] / 2 != n_channels: - raise ValueError("the number of blocks in prior_means does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_means = prior_means[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_means = prior_means[2 * channel : 2 * channel + 2, :] else: tmp_prior_means = prior_means else: tmp_prior_means = prior_means - if (prior_means is not None) & mix_prior_and_random & (npr.uniform() > 0.5): + if ( + (prior_means is not None) + & mix_prior_and_random + & (npr.uniform() > 0.5) + ): tmp_prior_means = None if isinstance(prior_stds, np.ndarray): if (prior_stds.shape[0] > 2) & use_specific_stats_for_channel: if prior_stds.shape[0] / 2 != n_channels: - raise ValueError("the number of blocks in prior_stds does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_stds = prior_stds[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_stds = prior_stds[2 * channel : 2 * channel + 2, :] else: tmp_prior_stds = prior_stds else: tmp_prior_stds = prior_stds - if (prior_stds is not None) & mix_prior_and_random & (npr.uniform() > 0.5): + if ( + (prior_stds is not None) + & mix_prior_and_random + & (npr.uniform() > 0.5) + ): tmp_prior_stds = None # draw means and std devs from priors - tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_classes, prior_distributions, - 125., 125., positive_only=True) - tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_classes, prior_distributions, - 15., 15., positive_only=True) + tmp_classes_means = utils.draw_value_from_distribution( + tmp_prior_means, + n_classes, + prior_distributions, + 125.0, + 125.0, + positive_only=True, + ) + tmp_classes_stds = utils.draw_value_from_distribution( + tmp_prior_stds, + n_classes, + prior_distributions, + 15.0, + 15.0, + positive_only=True, + ) random_coef = npr.uniform() if random_coef > 0.95: # reset the background to 0 in 5% of cases tmp_classes_means[0] = 0 tmp_classes_stds[0] = 0 - elif random_coef > 0.7: # reset the background to low Gaussian in 25% of cases + elif ( + random_coef > 0.7 + ): # reset the background to low Gaussian in 25% of cases tmp_classes_means[0] = npr.uniform(0, 15) tmp_classes_stds[0] = npr.uniform(0, 5) - tmp_means = utils.add_axis(tmp_classes_means[generation_classes], axis=[0, -1]) - tmp_stds = utils.add_axis(tmp_classes_stds[generation_classes], axis=[0, -1]) + tmp_means = utils.add_axis( + tmp_classes_means[generation_classes], axis=[0, -1] + ) + tmp_stds = utils.add_axis( + tmp_classes_stds[generation_classes], axis=[0, -1] + ) means = np.concatenate([means, tmp_means], axis=-1) stds = np.concatenate([stds, tmp_stds], axis=-1) list_means.append(means) diff --git a/nobrainer/ext/lab2im/__init__.py b/nobrainer/ext/lab2im/__init__.py index d3fd52d0..f26d7db9 100644 --- a/nobrainer/ext/lab2im/__init__.py +++ b/nobrainer/ext/lab2im/__init__.py @@ -1,6 +1 @@ -from . import edit_tensors -from . import edit_volumes -from . import image_generator -from . import lab2im_model -from . import layers -from . import utils +from . import edit_tensors, edit_volumes, image_generator, lab2im_model, layers, utils diff --git a/nobrainer/ext/lab2im/edit_tensors.py b/nobrainer/ext/lab2im/edit_tensors.py index 65e72a02..c34d0374 100644 --- a/nobrainer/ext/lab2im/edit_tensors.py +++ b/nobrainer/ext/lab2im/edit_tensors.py @@ -22,13 +22,14 @@ """ +from itertools import combinations + +import keras.backend as K +import keras.layers as KL # python imports import numpy as np import tensorflow as tf -import keras.layers as KL -import keras.backend as K -from itertools import combinations # project imports from nobrainer.ext.lab2im import utils @@ -38,7 +39,9 @@ from nobrainer.ext.neuron.utils import volshape_to_meshgrid -def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, thickness=None): +def blurring_sigma_for_downsampling( + current_res, downsample_res, mult_coef=None, thickness=None +): """Compute standard deviations of 1d gaussian masks for image blurring before downsampling. :param downsample_res: resolution to downsample to. Can be a 1d numpy array or list, or a tensor. :param current_res: resolution of the volume before downsampling. @@ -68,17 +71,32 @@ def blurring_sigma_for_downsampling(current_res, downsample_res, mult_coef=None, # reformat data resolution at which we blur if thickness is not None: - down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))([downsample_res, thickness]) + down_res = KL.Lambda(lambda x: tf.math.minimum(x[0], x[1]))( + [downsample_res, thickness] + ) else: down_res = downsample_res # get std deviation for blurring kernels if mult_coef is None: - sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x, tf.convert_to_tensor(current_res, dtype='float32')), - 0.5, 0.75 * x / tf.convert_to_tensor(current_res, dtype='float32')))(down_res) + sigma = KL.Lambda( + lambda x: tf.where( + tf.math.equal( + x, tf.convert_to_tensor(current_res, dtype="float32") + ), + 0.5, + 0.75 * x / tf.convert_to_tensor(current_res, dtype="float32"), + ) + )(down_res) else: - sigma = KL.Lambda(lambda x: mult_coef * x / tf.convert_to_tensor(current_res, dtype='float32'))(down_res) - sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.), 0., x[1]))([down_res, sigma]) + sigma = KL.Lambda( + lambda x: mult_coef + * x + / tf.convert_to_tensor(current_res, dtype="float32") + )(down_res) + sigma = KL.Lambda(lambda x: tf.where(tf.math.equal(x[0], 0.0), 0.0, x[1]))( + [down_res, sigma] + ) return sigma @@ -95,9 +113,13 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): """ # convert sigma into a tensor if not tf.is_tensor(sigma): - sigma_tens = tf.convert_to_tensor(utils.reformat_to_list(sigma), dtype='float32') + sigma_tens = tf.convert_to_tensor( + utils.reformat_to_list(sigma), dtype="float32" + ) else: - assert max_sigma is not None, 'max_sigma must be provided when sigma is given as a tensor' + assert ( + max_sigma is not None + ), "max_sigma must be provided when sigma is given as a tensor" sigma_tens = sigma shape = sigma_tens.get_shape().as_list() @@ -118,7 +140,9 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): # randomise the burring std dev and/or split it between dimensions if blur_range is not None: if blur_range != 1: - sigma_tens = sigma_tens * tf.random.uniform(tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range) + sigma_tens = sigma_tens * tf.random.uniform( + tf.shape(sigma_tens), minval=1 / blur_range, maxval=blur_range + ) # get size of blurring kernels windowsize = np.int32(np.ceil(2.5 * max_sigma) / 2) * 2 + 1 @@ -129,16 +153,23 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): kernels = list() comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) - for (i, wsize) in enumerate(windowsize): + for i, wsize in enumerate(windowsize): if wsize > 1: # build meshgrid and replicate it along batch dim if dynamic blurring - locations = tf.cast(tf.range(0, wsize), 'float32') - (wsize - 1) / 2 + locations = tf.cast(tf.range(0, wsize), "float32") - (wsize - 1) / 2 if batchsize is not None: - locations = tf.tile(tf.expand_dims(locations, axis=0), - tf.concat([batchsize, tf.ones(tf.shape(tf.shape(locations)), dtype='int32')], - axis=0)) + locations = tf.tile( + tf.expand_dims(locations, axis=0), + tf.concat( + [ + batchsize, + tf.ones(tf.shape(tf.shape(locations)), dtype="int32"), + ], + axis=0, + ), + ) comb[i] += 1 # compute gaussians @@ -156,13 +187,23 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): else: # build meshgrid - mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')] - diff = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1) + mesh = [ + tf.cast(f, "float32") + for f in volshape_to_meshgrid(windowsize, indexing="ij") + ] + diff = tf.stack( + [mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1 + ) # replicate meshgrid to batch size and reshape sigma_tens if batchsize is not None: - diff = tf.tile(tf.expand_dims(diff, axis=0), - tf.concat([batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype='int32')], axis=0)) + diff = tf.tile( + tf.expand_dims(diff, axis=0), + tf.concat( + [batchsize, tf.ones(tf.shape(tf.shape(diff)), dtype="int32")], + axis=0, + ), + ) for i in range(n_dims): sigma_tens = tf.expand_dims(sigma_tens, axis=1) else: @@ -171,8 +212,14 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): # compute gaussians sigma_is_0 = tf.equal(sigma_tens, 0) - exp_term = -K.square(diff) / (2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens)**2) - norms = exp_term - tf.math.log(tf.where(sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens)) + exp_term = -K.square(diff) / ( + 2 * tf.where(sigma_is_0, tf.ones_like(sigma_tens), sigma_tens) ** 2 + ) + norms = exp_term - tf.math.log( + tf.where( + sigma_is_0, tf.ones_like(sigma_tens), np.sqrt(2 * np.pi) * sigma_tens + ) + ) kernels = K.sum(norms, -1) kernels = tf.exp(kernels) kernels /= tf.reduce_sum(kernels) @@ -184,8 +231,8 @@ def gaussian_kernel(sigma, max_sigma=None, blur_range=None, separable=True): def sobel_kernels(n_dims): """Returns sobel kernels to compute spatial derivative on image of n dimensions.""" - in_dir = tf.convert_to_tensor([1, 0, -1], dtype='float32') - orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype='float32') + in_dir = tf.convert_to_tensor([1, 0, -1], dtype="float32") + orthogonal_dir = tf.convert_to_tensor([1, 2, 1], dtype="float32") comb = np.array(list(combinations(list(range(n_dims)), n_dims - 1))[::-1]) list_kernels = list() @@ -216,31 +263,49 @@ def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None): # convert dist_threshold into a tensor if not tf.is_tensor(dist_threshold): - dist_threshold_tens = tf.convert_to_tensor(utils.reformat_to_list(dist_threshold), dtype='float32') + dist_threshold_tens = tf.convert_to_tensor( + utils.reformat_to_list(dist_threshold), dtype="float32" + ) else: - assert max_dist_threshold is not None, 'max_sigma must be provided when dist_threshold is given as a tensor' - dist_threshold_tens = tf.cast(dist_threshold, 'float32') + assert ( + max_dist_threshold is not None + ), "max_sigma must be provided when dist_threshold is given as a tensor" + dist_threshold_tens = tf.cast(dist_threshold, "float32") shape = dist_threshold_tens.get_shape().as_list() # get batchsize - batchsize = None if shape[0] is not None else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0] + batchsize = ( + None + if shape[0] is not None + else tf.split(tf.shape(dist_threshold_tens), [1, -1])[0] + ) # set max_dist_threshold into an array - if max_dist_threshold is None: # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch) + if ( + max_dist_threshold is None + ): # dist_threshold is fixed (i.e. dist_threshold will not change at each mini-batch) max_dist_threshold = dist_threshold # get size of blurring kernels - windowsize = np.array([max_dist_threshold * 2 + 1]*n_dims, dtype='int32') + windowsize = np.array([max_dist_threshold * 2 + 1] * n_dims, dtype="int32") # build tensor representing the distance from the centre - mesh = [tf.cast(f, 'float32') for f in volshape_to_meshgrid(windowsize, indexing='ij')] - dist = tf.stack([mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1) + mesh = [ + tf.cast(f, "float32") for f in volshape_to_meshgrid(windowsize, indexing="ij") + ] + dist = tf.stack( + [mesh[f] - (windowsize[f] - 1) / 2 for f in range(len(windowsize))], axis=-1 + ) dist = tf.sqrt(tf.reduce_sum(tf.square(dist), axis=-1)) # replicate distance to batch size and reshape sigma_tens if batchsize is not None: - dist = tf.tile(tf.expand_dims(dist, axis=0), - tf.concat([batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype='int32')], axis=0)) + dist = tf.tile( + tf.expand_dims(dist, axis=0), + tf.concat( + [batchsize, tf.ones(tf.shape(tf.shape(dist)), dtype="int32")], axis=0 + ), + ) for i in range(n_dims - 1): dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=1) else: @@ -248,18 +313,24 @@ def unit_kernel(dist_threshold, n_dims, max_dist_threshold=None): dist_threshold_tens = tf.expand_dims(dist_threshold_tens, axis=0) # build final kernel by thresholding distance tensor - kernel = tf.where(tf.less_equal(dist, dist_threshold_tens), tf.ones_like(dist), tf.zeros_like(dist)) + kernel = tf.where( + tf.less_equal(dist, dist_threshold_tens), + tf.ones_like(dist), + tf.zeros_like(dist), + ) kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) return kernel -def resample_tensor(tensor, - resample_shape, - interp_method='linear', - subsample_res=None, - volume_res=None, - build_reliability_map=False): +def resample_tensor( + tensor, + resample_shape, + interp_method="linear", + subsample_res=None, + volume_res=None, + build_reliability_map=False, +): """This function resamples a volume to resample_shape. It does not apply any pre-filtering. A prior downsampling step can be added if subsample_res is specified. In this case, volume_res should also be specified, in order to calculate the downsampling ratio. A reliability map can also be returned to indicate which @@ -286,22 +357,35 @@ def resample_tensor(tensor, downsample_shape = tensor_shape # will be modified if we actually downsample if subsample_res is not None: - assert volume_res is not None, 'volume_res must be given when providing a subsampling resolution.' - assert len(subsample_res) == len(volume_res), 'subsample_res and volume_res must have the same length, ' \ - 'had {0}, and {1}'.format(len(subsample_res), len(volume_res)) + assert ( + volume_res is not None + ), "volume_res must be given when providing a subsampling resolution." + assert len(subsample_res) == len(volume_res), ( + "subsample_res and volume_res must have the same length, " + "had {0}, and {1}".format(len(subsample_res), len(volume_res)) + ) if subsample_res != volume_res: # get shape at which we downsample - downsample_shape = [int(tensor_shape[i] * volume_res[i] / subsample_res[i]) for i in range(n_dims)] + downsample_shape = [ + int(tensor_shape[i] * volume_res[i] / subsample_res[i]) + for i in range(n_dims) + ] # downsample volume tensor._keras_shape = tuple(tensor.get_shape().as_list()) - tensor = nrn_layers.Resize(size=downsample_shape, interp_method='nearest')(tensor) + tensor = nrn_layers.Resize(size=downsample_shape, interp_method="nearest")( + tensor + ) # resample image at target resolution - if resample_shape != downsample_shape: # if we didn't downsample downsample_shape = tensor_shape + if ( + resample_shape != downsample_shape + ): # if we didn't downsample downsample_shape = tensor_shape tensor._keras_shape = tuple(tensor.get_shape().as_list()) - tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)(tensor) + tensor = nrn_layers.Resize(size=resample_shape, interp_method=interp_method)( + tensor + ) # compute reliability maps if necessary and return results if build_reliability_map: @@ -320,13 +404,20 @@ def resample_tensor(tensor, loc_ceil = np.int32(np.clip(loc_floor + 1, 0, resample_shape[i] - 1)) tmp_reliability_map = np.zeros(resample_shape[i]) tmp_reliability_map[loc_floor] = 1 - (loc_float - loc_floor) - tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + (loc_float - loc_floor) + tmp_reliability_map[loc_ceil] = tmp_reliability_map[loc_ceil] + ( + loc_float - loc_floor + ) shape = [1, 1, 1] shape[i] = resample_shape[i] - reliability_map = reliability_map * np.reshape(tmp_reliability_map, shape) + reliability_map = reliability_map * np.reshape( + tmp_reliability_map, shape + ) shape = KL.Lambda(lambda x: tf.shape(x))(tensor) - mask = KL.Lambda(lambda x: tf.reshape(tf.convert_to_tensor(reliability_map, dtype='float32'), - shape=x))(shape) + mask = KL.Lambda( + lambda x: tf.reshape( + tf.convert_to_tensor(reliability_map, dtype="float32"), shape=x + ) + )(shape) # otherwise just return an all-one tensor else: diff --git a/nobrainer/ext/lab2im/edit_volumes.py b/nobrainer/ext/lab2im/edit_volumes.py index c8e388a5..9bed9453 100644 --- a/nobrainer/ext/lab2im/edit_volumes.py +++ b/nobrainer/ext/lab2im/edit_volumes.py @@ -69,31 +69,41 @@ License. """ +import csv # python imports import os -import csv import shutil -import numpy as np -import tensorflow as tf + import keras.layers as KL from keras.models import Model -from scipy.ndimage.filters import convolve -from scipy.ndimage import label as scipy_label +import numpy as np from scipy.interpolate import RegularGridInterpolator -from scipy.ndimage.morphology import distance_transform_edt, binary_fill_holes from scipy.ndimage import binary_dilation, binary_erosion, gaussian_filter +from scipy.ndimage import label as scipy_label +from scipy.ndimage.filters import convolve +from scipy.ndimage.morphology import binary_fill_holes, distance_transform_edt +import tensorflow as tf # project imports from nobrainer.ext.lab2im import utils -from nobrainer.ext.lab2im.layers import GaussianBlur, ConvertLabels from nobrainer.ext.lab2im.edit_tensors import blurring_sigma_for_downsampling - +from nobrainer.ext.lab2im.layers import ConvertLabels, GaussianBlur # ---------------------------------------------------- edit volume ----------------------------------------------------- -def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, masking_value=0, - return_mask=False, return_copy=True): + +def mask_volume( + volume, + mask=None, + threshold=0.1, + dilate=0, + erode=0, + fill_holes=False, + masking_value=0, + return_mask=False, + return_copy=True, +): """Mask a volume, either with a given mask, or by keeping only the values above a threshold. :param volume: a numpy array, possibly with several channels :param mask: (optional) a numpy array to mask volume with. @@ -119,8 +129,11 @@ def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes= if mask is None: mask = new_volume >= threshold else: - assert list(mask.shape[:n_dims]) == vol_shape[:n_dims], 'mask should have shape {0}, or {1}, had {2}'.format( - vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape)) + assert ( + list(mask.shape[:n_dims]) == vol_shape[:n_dims] + ), "mask should have shape {0}, or {1}, had {2}".format( + vol_shape[:n_dims], vol_shape[:n_dims] + [n_channels], list(mask.shape) + ) mask = mask > 0 if dilate > 0: dilate_struct = utils.build_binary_structure(dilate, n_dims) @@ -137,7 +150,9 @@ def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes= if mask_to_apply.shape == new_volume.shape: new_volume[np.logical_not(mask_to_apply)] = masking_value else: - new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = masking_value + new_volume[np.stack([np.logical_not(mask_to_apply)] * n_channels, axis=-1)] = ( + masking_value + ) if return_mask: return new_volume, mask_to_apply @@ -145,7 +160,14 @@ def mask_volume(volume, mask=None, threshold=0.1, dilate=0, erode=0, fill_holes= return new_volume -def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2, max_percentile=98, use_positive_only=False): +def rescale_volume( + volume, + new_min=0, + new_max=255, + min_percentile=2, + max_percentile=98, + use_positive_only=False, +): """This function linearly rescales a volume between new_min and new_max. :param volume: a numpy array :param new_min: (optional) minimum value for the rescaled image. @@ -160,23 +182,42 @@ def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2, max_percent # select only positive intensities new_volume = volume.copy() - intensities = new_volume[new_volume > 0] if use_positive_only else new_volume.flatten() + intensities = ( + new_volume[new_volume > 0] if use_positive_only else new_volume.flatten() + ) # define min and max intensities in original image for normalisation - robust_min = np.min(intensities) if min_percentile == 0 else np.percentile(intensities, min_percentile) - robust_max = np.max(intensities) if max_percentile == 100 else np.percentile(intensities, max_percentile) + robust_min = ( + np.min(intensities) + if min_percentile == 0 + else np.percentile(intensities, min_percentile) + ) + robust_max = ( + np.max(intensities) + if max_percentile == 100 + else np.percentile(intensities, max_percentile) + ) # trim values outside range new_volume = np.clip(new_volume, robust_min, robust_max) # rescale image if robust_min != robust_max: - return new_min + (new_volume - robust_min) / (robust_max - robust_min) * (new_max - new_min) + return new_min + (new_volume - robust_min) / (robust_max - robust_min) * ( + new_max - new_min + ) else: # avoid dividing by zero return np.zeros_like(new_volume) -def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'): +def crop_volume( + volume, + cropping_margin=None, + cropping_shape=None, + aff=None, + return_crop_idx=False, + mode="center", +): """Crop volume by a given margin, or to a given shape. :param volume: 2d or 3d numpy array (possibly with multiple channels) :param cropping_margin: (optional) margin by which to crop the volume. The cropping margin is applied on both sides. @@ -192,10 +233,12 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret True (in that order). """ - assert (cropping_margin is not None) | (cropping_shape is not None), \ - 'cropping_margin or cropping_shape should be provided' - assert not ((cropping_margin is not None) & (cropping_shape is not None)), \ - 'only one of cropping_margin or cropping_shape should be provided' + assert (cropping_margin is not None) | ( + cropping_shape is not None + ), "cropping_margin or cropping_shape should be provided" + assert not ( + (cropping_margin is not None) & (cropping_shape is not None) + ), "only one of cropping_margin or cropping_shape should be provided" # get info new_volume = volume.copy() @@ -206,27 +249,49 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret if cropping_margin is not None: cropping_margin = utils.reformat_to_list(cropping_margin, length=n_dims) do_cropping = np.array(vol_shape[:n_dims]) > 2 * np.array(cropping_margin) - min_crop_idx = [cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims)] - max_crop_idx = [vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] for i in range(n_dims)] + min_crop_idx = [ + cropping_margin[i] if do_cropping[i] else 0 for i in range(n_dims) + ] + max_crop_idx = [ + vol_shape[i] - cropping_margin[i] if do_cropping[i] else vol_shape[i] + for i in range(n_dims) + ] else: cropping_shape = utils.reformat_to_list(cropping_shape, length=n_dims) - if mode == 'center': - min_crop_idx = np.maximum([int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0) - max_crop_idx = np.minimum([min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)], - np.array(vol_shape)[:n_dims]) - elif mode == 'random': - crop_max_val = np.maximum(np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0) + if mode == "center": + min_crop_idx = np.maximum( + [int((vol_shape[i] - cropping_shape[i]) / 2) for i in range(n_dims)], 0 + ) + max_crop_idx = np.minimum( + [min_crop_idx[i] + cropping_shape[i] for i in range(n_dims)], + np.array(vol_shape)[:n_dims], + ) + elif mode == "random": + crop_max_val = np.maximum( + np.array([vol_shape[i] - cropping_shape[i] for i in range(n_dims)]), 0 + ) min_crop_idx = np.random.randint(0, high=crop_max_val + 1) - max_crop_idx = np.minimum(min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims]) + max_crop_idx = np.minimum( + min_crop_idx + np.array(cropping_shape), np.array(vol_shape)[:n_dims] + ) else: - raise ValueError('mode should be either "center" or "random", had %s' % mode) + raise ValueError( + 'mode should be either "center" or "random", had %s' % mode + ) crop_idx = np.concatenate([np.array(min_crop_idx), np.array(max_crop_idx)]) # crop volume if n_dims == 2: - new_volume = new_volume[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...] + new_volume = new_volume[ + crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ... + ] elif n_dims == 3: - new_volume = new_volume[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...] + new_volume = new_volume[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] # sort outputs output = [new_volume] @@ -238,15 +303,17 @@ def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, ret return output[0] if len(output) == 1 else tuple(output) -def crop_volume_around_region(volume, - mask=None, - masking_labels=None, - threshold=0.1, - margin=0, - cropping_shape=None, - cropping_shape_div_by=None, - aff=None, - overflow='strict'): +def crop_volume_around_region( + volume, + mask=None, + masking_labels=None, + threshold=0.1, + margin=0, + cropping_shape=None, + cropping_shape_div_by=None, + aff=None, + overflow="strict", +): """Crop a volume around a specific region. This region is defined by a mask obtained by either: 1) directly specifying it as input (see mask) @@ -285,11 +352,15 @@ def crop_volume_around_region(volume, and the updated affine matrix if aff is not None. """ - assert not ((margin > 0) & (cropping_shape is not None)), "margin and cropping_shape can't be given together." - assert not ((margin > 0) & (cropping_shape_div_by is not None)), \ - "margin and cropping_shape_div_by can't be given together." - assert not ((cropping_shape_div_by is not None) & (cropping_shape is not None)), \ - "cropping_shape_div_by and cropping_shape can't be given together." + assert not ( + (margin > 0) & (cropping_shape is not None) + ), "margin and cropping_shape can't be given together." + assert not ( + (margin > 0) & (cropping_shape_div_by is not None) + ), "margin and cropping_shape_div_by can't be given together." + assert not ( + (cropping_shape_div_by is not None) & (cropping_shape is not None) + ), "cropping_shape_div_by and cropping_shape can't be given together." new_vol = volume.copy() n_dims, n_channels = utils.get_dims(new_vol.shape) @@ -298,7 +369,9 @@ def crop_volume_around_region(volume, # mask ROIs for cropping if mask is None: if masking_labels is not None: - _, mask = mask_label_map(new_vol, masking_values=masking_labels, return_mask=True) + _, mask = mask_label_map( + new_vol, masking_values=masking_labels, return_mask=True + ) else: mask = new_vol > threshold @@ -315,25 +388,35 @@ def crop_volume_around_region(volume, if margin: cropping_shape = intermediate_vol_shape + 2 * margin elif cropping_shape is not None: - cropping_shape = np.array(utils.reformat_to_list(cropping_shape, length=n_dims)) + cropping_shape = np.array( + utils.reformat_to_list(cropping_shape, length=n_dims) + ) elif cropping_shape_div_by is not None: - cropping_shape = [utils.find_closest_number_divisible_by_m(s, cropping_shape_div_by, answer_type='higher') - for s in intermediate_vol_shape] - - min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + cropping_shape = [ + utils.find_closest_number_divisible_by_m( + s, cropping_shape_div_by, answer_type="higher" + ) + for s in intermediate_vol_shape + ] + + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) min_overflow = np.abs(np.minimum(min_idx, 0)) max_overflow = np.maximum(max_idx - vol_shape, 0) - if 'strict' in overflow: + if "strict" in overflow: min_overflow = np.zeros_like(min_overflow) max_overflow = np.zeros_like(min_overflow) - if overflow == 'shift-strict': + if overflow == "shift-strict": min_idx -= max_overflow max_idx += min_overflow - if overflow == 'shift-padding': + if overflow == "shift-padding": for ii in range(n_dims): # no need to do anything if both min/max_overflow are 0 (no padding/shifting required at all) # or if both are positive, because in this case we don't shift at all and we pad directly @@ -343,7 +426,9 @@ def crop_volume_around_region(volume, max_idx[ii] = max_idx_new min_overflow[ii] = 0 else: - min_overflow[ii] = min_overflow[ii] - (vol_shape[ii] - max_idx[ii]) + min_overflow[ii] = min_overflow[ii] - ( + vol_shape[ii] - max_idx[ii] + ) max_idx[ii] = vol_shape[ii] elif (min_overflow[ii] == 0) & (max_overflow[ii] > 0): min_idx_new = min_idx[ii] - max_overflow[ii] @@ -360,17 +445,28 @@ def crop_volume_around_region(volume, cropping = np.concatenate([min_idx, max_idx]) if np.any(cropping[:3] > 0) or np.any(cropping[3:] != vol_shape): if n_dims == 3: - new_vol = new_vol[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] + new_vol = new_vol[ + cropping[0] : cropping[3], + cropping[1] : cropping[4], + cropping[2] : cropping[5], + ..., + ] elif n_dims == 2: - new_vol = new_vol[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] + new_vol = new_vol[ + cropping[0] : cropping[2], cropping[1] : cropping[3], ... + ] else: - raise ValueError('cannot crop volumes with more than 3 dimensions') + raise ValueError("cannot crop volumes with more than 3 dimensions") # pad volume if necessary if np.any(min_overflow > 0) | np.any(max_overflow > 0): - pad_margins = tuple([(min_overflow[i], max_overflow[i]) for i in range(n_dims)]) - pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins - new_vol = np.pad(new_vol, pad_margins, mode='constant', constant_values=0) + pad_margins = tuple( + [(min_overflow[i], max_overflow[i]) for i in range(n_dims)] + ) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins + ) + new_vol = np.pad(new_vol, pad_margins, mode="constant", constant_values=0) # if there's nothing to crop around, we return the input as is else: @@ -408,11 +504,18 @@ def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=Tr # crop image if n_dims == 2: - new_volume = new_volume[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...] + new_volume = new_volume[ + crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ... + ] elif n_dims == 3: - new_volume = new_volume[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...] + new_volume = new_volume[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] else: - raise Exception('cannot crop volumes with more than 3 dimensions') + raise Exception("cannot crop volumes with more than 3 dimensions") if aff is not None: aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ crop_idx[:3] @@ -436,21 +539,38 @@ def pad_volume(volume, padding_shape, padding_value=0, aff=None, return_pad_idx= new_volume = volume.copy() vol_shape = new_volume.shape n_dims, n_channels = utils.get_dims(vol_shape) - padding_shape = utils.reformat_to_list(padding_shape, length=n_dims, dtype='int') + padding_shape = utils.reformat_to_list(padding_shape, length=n_dims, dtype="int") # check if need to pad - if np.any(np.array(padding_shape, dtype='int32') > np.array(vol_shape[:n_dims], dtype='int32')): + if np.any( + np.array(padding_shape, dtype="int32") + > np.array(vol_shape[:n_dims], dtype="int32") + ): # get padding margins - min_margins = np.maximum(np.int32(np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0) - max_margins = np.maximum(np.int32(np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2)), 0) - pad_idx = np.concatenate([min_margins, min_margins + np.array(vol_shape[:n_dims])]) + min_margins = np.maximum( + np.int32( + np.floor((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2) + ), + 0, + ) + max_margins = np.maximum( + np.int32( + np.ceil((np.array(padding_shape) - np.array(vol_shape)[:n_dims]) / 2) + ), + 0, + ) + pad_idx = np.concatenate( + [min_margins, min_margins + np.array(vol_shape[:n_dims])] + ) pad_margins = tuple([(min_margins[i], max_margins[i]) for i in range(n_dims)]) if n_channels > 1: pad_margins = tuple(list(pad_margins) + [(0, 0)]) # pad volume - new_volume = np.pad(new_volume, pad_margins, mode='constant', constant_values=padding_value) + new_volume = np.pad( + new_volume, pad_margins, mode="constant", constant_values=padding_value + ) if aff is not None: if n_dims == 2: @@ -482,26 +602,29 @@ def flip_volume(volume, axis=None, direction=None, aff=None, return_copy=True): """ new_volume = volume.copy() if return_copy else volume - assert (axis is not None) | ((aff is not None) & (direction is not None)), \ - 'please provide either axis, or an affine matrix with a direction' + assert (axis is not None) | ( + (aff is not None) & (direction is not None) + ), "please provide either axis, or an affine matrix with a direction" # get flipping axis from aff if axis not provided if (axis is None) & (aff is not None): volume_axes = get_ras_axes(aff) - if direction == 'rl': + if direction == "rl": axis = volume_axes[0] - elif direction == 'ap': + elif direction == "ap": axis = volume_axes[1] - elif direction == 'si': + elif direction == "si": axis = volume_axes[2] else: - raise ValueError("direction should be 'rl', 'ap', or 'si', had %s" % direction) + raise ValueError( + "direction should be 'rl', 'ap', or 'si', had %s" % direction + ) # flip volume return np.flip(new_volume, axis=axis) -def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True): +def resample_volume(volume, aff, new_vox_size, interpolation="linear", blur=True): """This function resizes the voxels of a volume to a new provided size, while adjusting the header to keep the RAS :param volume: a numpy array :param aff: affine matrix of the volume @@ -525,9 +648,11 @@ def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True y = np.arange(0, volume_filt.shape[1]) z = np.arange(0, volume_filt.shape[2]) - my_interpolating_function = RegularGridInterpolator((x, y, z), volume_filt, method=interpolation) + my_interpolating_function = RegularGridInterpolator( + (x, y, z), volume_filt, method=interpolation + ) - start = - (factor - 1) / (2 * factor) + start = -(factor - 1) / (2 * factor) step = 1.0 / factor stop = start + step * np.ceil(volume_filt.shape * factor) @@ -541,7 +666,7 @@ def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True yi[yi > (volume_filt.shape[1] - 1)] = volume_filt.shape[1] - 1 zi[zi > (volume_filt.shape[2] - 1)] = volume_filt.shape[2] - 1 - xig, yig, zig = np.meshgrid(xi, yi, zi, indexing='ij', sparse=True) + xig, yig, zig = np.meshgrid(xi, yi, zi, indexing="ij", sparse=True) volume2 = my_interpolating_function((xig, yig, zig)) aff2 = aff.copy() @@ -552,7 +677,7 @@ def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True return volume2, aff2 -def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='linear'): +def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation="linear"): """This function reslices a floating image to the space of a reference image :param vol_ref: a numpy array with the reference volume :param aff_ref: affine matrix of the reference volume @@ -568,14 +693,15 @@ def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='line yf = np.arange(0, vol_flo.shape[1]) zf = np.arange(0, vol_flo.shape[2]) - my_interpolating_function = RegularGridInterpolator((xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, - method=interpolation) + my_interpolating_function = RegularGridInterpolator( + (xf, yf, zf), vol_flo, bounds_error=False, fill_value=0.0, method=interpolation + ) xr = np.arange(0, vol_ref.shape[0]) yr = np.arange(0, vol_ref.shape[1]) zr = np.arange(0, vol_ref.shape[2]) - xrg, yrg, zrg = np.meshgrid(xr, yr, zr, indexing='ij', sparse=False) + xrg, yrg, zrg = np.meshgrid(xr, yr, zr, indexing="ij", sparse=False) n = xrg.size xrg = xrg.reshape([n]) yrg = yrg.reshape([n]) @@ -583,7 +709,9 @@ def resample_volume_like(vol_ref, aff_ref, vol_flo, aff_flo, interpolation='line bottom = np.ones_like(xrg) coords = np.stack([xrg, yrg, zrg, bottom]) coords_new = np.matmul(T, coords)[:-1, :] - result = my_interpolating_function((coords_new[0, :], coords_new[1, :], coords_new[2, :])) + result = my_interpolating_function( + (coords_new[0, :], coords_new[1, :], coords_new[2, :]) + ) return result.reshape(vol_ref.shape) @@ -606,7 +734,9 @@ def get_ras_axes(aff, n_dims=3): return img_ras_axes -def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True): +def align_volume_to_ref( + volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True +): """This function aligns a volume to a reference orientation (axis and direction) specified by an affine matrix. :param volume: a numpy array :param aff: affine matrix of the floating volume @@ -638,14 +768,17 @@ def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None if ras_axes_flo[i] != ras_axes_ref[i]: new_volume = np.swapaxes(new_volume, ras_axes_flo[i], ras_axes_ref[i]) swapped_axis_idx = np.where(ras_axes_flo == ras_axes_ref[i]) - ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ras_axes_flo[i], ras_axes_flo[swapped_axis_idx] + ras_axes_flo[swapped_axis_idx], ras_axes_flo[i] = ( + ras_axes_flo[i], + ras_axes_flo[swapped_axis_idx], + ) # align directions dot_products = np.sum(aff_flo[:3, :3] * aff_ref[:3, :3], axis=0) for i in range(n_dims): if dot_products[i] < 0: new_volume = np.flip(new_volume, axis=i) - aff_flo[:, i] = - aff_flo[:, i] + aff_flo[:, i] = -aff_flo[:, i] aff_flo[:3, 3] = aff_flo[:3, 3] - aff_flo[:3, i] * (new_volume.shape[i] - 1) if return_aff: @@ -666,17 +799,21 @@ def blur_volume(volume, sigma, mask=None): # initialisation new_volume = volume.copy() n_dims, _ = utils.get_dims(new_volume.shape) - sigma = utils.reformat_to_list(sigma, length=n_dims, dtype='float') + sigma = utils.reformat_to_list(sigma, length=n_dims, dtype="float") # blur image - new_volume = gaussian_filter(new_volume, sigma=sigma, mode='nearest') # nearest refers to edge padding + new_volume = gaussian_filter( + new_volume, sigma=sigma, mode="nearest" + ) # nearest refers to edge padding # correct edge effect if mask is not None if mask is not None: - assert new_volume.shape == mask.shape, 'volume and mask should have the same dimensions: ' \ - 'got {0} and {1}'.format(new_volume.shape, mask.shape) + assert new_volume.shape == mask.shape, ( + "volume and mask should have the same dimensions: " + "got {0} and {1}".format(new_volume.shape, mask.shape) + ) mask = (mask > 0) * 1.0 - blurred_mask = gaussian_filter(mask, sigma=sigma, mode='nearest') + blurred_mask = gaussian_filter(mask, sigma=sigma, mode="nearest") new_volume = new_volume / (blurred_mask + 1e-6) new_volume[mask == 0] = 0 @@ -685,8 +822,15 @@ def blur_volume(volume, sigma, mask=None): # --------------------------------------------------- edit label map --------------------------------------------------- -def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, use_nearest_label=False, - remove_zero=False, smooth=False): + +def correct_label_map( + labels, + list_incorrect_labels, + list_correct_labels=None, + use_nearest_label=False, + remove_zero=False, + smooth=False, +): """This function corrects specified label values in a label map by either a list of given values, or by the nearest label. :param labels: a 2d or 3d label map @@ -703,27 +847,39 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u :return: corrected label map """ - assert (list_correct_labels is not None) | use_nearest_label, \ - 'please provide a list of correct labels, or set use_nearest_label to True.' - assert (list_correct_labels is None) | (not use_nearest_label), \ - 'cannot provide a list of correct values and set use_nearest_label to True' + assert ( + list_correct_labels is not None + ) | use_nearest_label, ( + "please provide a list of correct labels, or set use_nearest_label to True." + ) + assert (list_correct_labels is None) | ( + not use_nearest_label + ), "cannot provide a list of correct values and set use_nearest_label to True" # initialisation new_labels = labels.copy() - list_incorrect_labels = utils.reformat_to_list(utils.load_array_if_path(list_incorrect_labels)) + list_incorrect_labels = utils.reformat_to_list( + utils.load_array_if_path(list_incorrect_labels) + ) volume_labels = np.unique(labels) n_dims, _ = utils.get_dims(labels.shape) # use list of correct values if list_correct_labels is not None: - list_correct_labels = utils.reformat_to_list(utils.load_array_if_path(list_correct_labels)) + list_correct_labels = utils.reformat_to_list( + utils.load_array_if_path(list_correct_labels) + ) # loop over label values - for incorrect_label, correct_label in zip(list_incorrect_labels, list_correct_labels): + for incorrect_label, correct_label in zip( + list_incorrect_labels, list_correct_labels + ): if incorrect_label in volume_labels: # only one possible value to replace with - if isinstance(correct_label, (int, float, np.int64, np.int32, np.int16, np.int8)): + if isinstance( + correct_label, (int, float, np.int64, np.int32, np.int16, np.int8) + ): incorrect_voxels = np.where(labels == incorrect_label) new_labels[incorrect_voxels] = correct_label @@ -732,8 +888,12 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u # make sure at least one correct label is present if not any([lab in volume_labels for lab in correct_label]): - print('no correct values found in volume, please adjust: ' - 'incorrect: {}, correct: {}'.format(incorrect_label, correct_label)) + print( + "no correct values found in volume, please adjust: " + "incorrect: {}, correct: {}".format( + incorrect_label, correct_label + ) + ) # crop around incorrect label until we find incorrect labels correct_label_not_found = True @@ -741,21 +901,34 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u tmp_labels = None crop = None while correct_label_not_found: - tmp_labels, crop = crop_volume_around_region(labels, - masking_labels=incorrect_label, - margin=10 * margin_mult) - correct_label_not_found = not any([lab in np.unique(tmp_labels) for lab in correct_label]) + tmp_labels, crop = crop_volume_around_region( + labels, + masking_labels=incorrect_label, + margin=10 * margin_mult, + ) + correct_label_not_found = not any( + [lab in np.unique(tmp_labels) for lab in correct_label] + ) margin_mult += 1 # calculate distance maps for all new label candidates incorrect_voxels = np.where(tmp_labels == incorrect_label) - distance_map_list = [distance_transform_edt(tmp_labels != lab) for lab in correct_label] - distances_correct = np.stack([dist[incorrect_voxels] for dist in distance_map_list]) + distance_map_list = [ + distance_transform_edt(tmp_labels != lab) + for lab in correct_label + ] + distances_correct = np.stack( + [dist[incorrect_voxels] for dist in distance_map_list] + ) # select nearest values and use them to correct label map idx_correct_lab = np.argmin(distances_correct, axis=0) - incorrect_voxels = tuple([incorrect_voxels[i] + crop[i] for i in range(n_dims)]) - new_labels[incorrect_voxels] = np.array(correct_label)[idx_correct_lab] + incorrect_voxels = tuple( + [incorrect_voxels[i] + crop[i] for i in range(n_dims)] + ) + new_labels[incorrect_voxels] = np.array(correct_label)[ + idx_correct_lab + ] # use nearest label else: @@ -766,21 +939,27 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u # loop around regions components, n_components = scipy_label(labels == incorrect_label) - loop_info = utils.LoopInfo(n_components + 1, 100, 'correcting') + loop_info = utils.LoopInfo(n_components + 1, 100, "correcting") for i in range(1, n_components + 1): loop_info.update(i) # crop each region - _, crop = crop_volume_around_region(components, masking_labels=i, margin=1) + _, crop = crop_volume_around_region( + components, masking_labels=i, margin=1 + ) tmp_labels = crop_volume_with_idx(labels, crop) tmp_new_labels = crop_volume_with_idx(new_labels, crop) # list all possible correct labels correct_labels = np.unique(tmp_labels) for il in list_incorrect_labels: - correct_labels = np.delete(correct_labels, np.where(correct_labels == il)) + correct_labels = np.delete( + correct_labels, np.where(correct_labels == il) + ) if remove_zero: - correct_labels = np.delete(correct_labels, np.where(correct_labels == 0)) + correct_labels = np.delete( + correct_labels, np.where(correct_labels == 0) + ) # replace incorrect voxels by new value incorrect_voxels = np.where(tmp_labels == incorrect_label) @@ -788,18 +967,31 @@ def correct_label_map(labels, list_incorrect_labels, list_correct_labels=None, u tmp_new_labels[incorrect_voxels] = -1 else: if len(correct_labels) == 1: - idx_correct_lab = np.zeros(len(incorrect_voxels[0]), dtype='int32') + idx_correct_lab = np.zeros( + len(incorrect_voxels[0]), dtype="int32" + ) else: - distance_map_list = [distance_transform_edt(tmp_labels != lab) for lab in correct_labels] - distances_correct = np.stack([dist[incorrect_voxels] for dist in distance_map_list]) + distance_map_list = [ + distance_transform_edt(tmp_labels != lab) + for lab in correct_labels + ] + distances_correct = np.stack( + [dist[incorrect_voxels] for dist in distance_map_list] + ) idx_correct_lab = np.argmin(distances_correct, axis=0) - tmp_new_labels[incorrect_voxels] = np.array(correct_labels)[idx_correct_lab] + tmp_new_labels[incorrect_voxels] = np.array(correct_labels)[ + idx_correct_lab + ] # paste back if n_dims == 2: - new_labels[crop[0]:crop[2], crop[1]:crop[3], ...] = tmp_new_labels + new_labels[crop[0] : crop[2], crop[1] : crop[3], ...] = ( + tmp_new_labels + ) else: - new_labels[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] = tmp_new_labels + new_labels[ + crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ... + ] = tmp_new_labels # smoothing if smooth: @@ -843,18 +1035,20 @@ def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): """ # get info labels_shape = labels.shape - unique_labels = np.unique(labels).astype('int32') + unique_labels = np.unique(labels).astype("int32") if labels_list is None: labels_list = unique_labels new_labels = mask_new_labels = None else: labels_to_keep = [lab for lab in unique_labels if lab not in labels_list] - new_labels, mask_new_labels = mask_label_map(labels, labels_to_keep, return_mask=True) + new_labels, mask_new_labels = mask_label_map( + labels, labels_to_keep, return_mask=True + ) # loop through label values count = np.zeros(labels_shape) - labels_smoothed = np.zeros(labels_shape, dtype='int') - loop_info = utils.LoopInfo(len(labels_list), print_progress, 'smoothing') + labels_smoothed = np.zeros(labels_shape, dtype="int") + loop_info = utils.LoopInfo(len(labels_list), print_progress, "smoothing") for la, label in enumerate(labels_list): if print_progress: loop_info.update(la) @@ -867,7 +1061,7 @@ def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): idx = n_neighbours > count count[idx] = n_neighbours[idx] labels_smoothed[idx] = label - labels_smoothed = labels_smoothed.astype('int32') + labels_smoothed = labels_smoothed.astype("int32") if new_labels is None: new_labels = labels_smoothed @@ -877,7 +1071,14 @@ def smooth_label_map(labels, kernel, labels_list=None, print_progress=0): return new_labels -def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, model=None, return_model=False): +def erode_label_map( + labels, + labels_to_erode, + erosion_factors=1.0, + gpu=False, + model=None, + return_model=False, +): """Erode a given set of label values within a label map. :param labels: a 2d or 3d label map :param labels_to_erode: list of label values to erode @@ -893,17 +1094,23 @@ def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, mode # reformat labels_to_erode and erode new_labels = labels.copy() labels_to_erode = utils.reformat_to_list(labels_to_erode) - erosion_factors = utils.reformat_to_list(erosion_factors, length=len(labels_to_erode)) + erosion_factors = utils.reformat_to_list( + erosion_factors, length=len(labels_to_erode) + ) labels_shape = list(new_labels.shape) n_dims, _ = utils.get_dims(labels_shape) # loop over labels to erode for label_to_erode, erosion_factor in zip(labels_to_erode, erosion_factors): - assert erosion_factor > 0, 'all erosion factors should be strictly positive, had {}'.format(erosion_factor) + assert ( + erosion_factor > 0 + ), "all erosion factors should be strictly positive, had {}".format( + erosion_factor + ) # get mask of current label value - mask = (new_labels == label_to_erode) + mask = new_labels == label_to_erode # erode as usual if erosion factor is int if int(erosion_factor) == erosion_factor: @@ -914,12 +1121,14 @@ def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, mode else: if gpu: if model is None: - mask_in = KL.Input(shape=labels_shape + [1], dtype='float32') + mask_in = KL.Input(shape=labels_shape + [1], dtype="float32") blurred_mask = GaussianBlur([1] * 3)(mask_in) model = Model(inputs=mask_in, outputs=blurred_mask) - eroded_mask = model.predict(utils.add_axis(np.float32(mask), axis=[0, -1])) + eroded_mask = model.predict( + utils.add_axis(np.float32(mask), axis=[0, -1]) + ) else: - eroded_mask = blur_volume(np.array(mask, dtype='float32'), 1) + eroded_mask = blur_volume(np.array(mask, dtype="float32"), 1) eroded_mask = np.squeeze(eroded_mask) > erosion_factor # crop label map and mask around values to change @@ -930,16 +1139,28 @@ def erode_label_map(labels, labels_to_erode, erosion_factors=1., gpu=False, mode # calculate distance maps for all labels in cropped_labels labels_list = np.unique(cropped_labels) labels_list = labels_list[labels_list != label_to_erode] - list_dist_maps = [distance_transform_edt(np.logical_not(cropped_labels == la)) for la in labels_list] - candidate_distances = np.stack([dist[cropped_lab_mask] for dist in list_dist_maps]) + list_dist_maps = [ + distance_transform_edt(np.logical_not(cropped_labels == la)) + for la in labels_list + ] + candidate_distances = np.stack( + [dist[cropped_lab_mask] for dist in list_dist_maps] + ) # select nearest value and put cropped labels back to full label map idx_correct_lab = np.argmin(candidate_distances, axis=0) cropped_labels[cropped_lab_mask] = np.array(labels_list)[idx_correct_lab] if n_dims == 2: - new_labels[cropping[0]:cropping[2], cropping[1]:cropping[3], ...] = cropped_labels + new_labels[cropping[0] : cropping[2], cropping[1] : cropping[3], ...] = ( + cropped_labels + ) elif n_dims == 3: - new_labels[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...] = cropped_labels + new_labels[ + cropping[0] : cropping[3], + cropping[1] : cropping[4], + cropping[2] : cropping[5], + ..., + ] = cropped_labels if return_model: return new_labels, model @@ -953,10 +1174,16 @@ def get_largest_connected_component(mask, structure=None): :param structure: numpy array defining the connectivity. """ components, n_components = scipy_label(mask, structure) - return components == np.argmax(np.bincount(components.flat)[1:]) + 1 if n_components > 0 else mask.copy() + return ( + components == np.argmax(np.bincount(components.flat)[1:]) + 1 + if n_components > 0 + else mask.copy() + ) -def compute_hard_volumes(labels, voxel_volume=1., label_list=None, skip_background=True): +def compute_hard_volumes( + labels, voxel_volume=1.0, label_list=None, skip_background=True +): """Compute hard volumes in a label map. :param labels: a label map :param voxel_volume: (optional) volume of voxel. Default is 1 (i.e. returned volumes are voxel counts). @@ -969,7 +1196,7 @@ def compute_hard_volumes(labels, voxel_volume=1., label_list=None, skip_backgrou """ # initialisation - subject_label_list = utils.reformat_to_list(np.unique(labels), dtype='int') + subject_label_list = utils.reformat_to_list(np.unique(labels), dtype="int") if label_list is None: label_list = subject_label_list else: @@ -996,7 +1223,8 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): for these labels only. Default is None, where all positive values are considered. :param crop_margin: (optional) margin with which to crop the input label maps around the labels for which we want to compute the distance maps. - :return: a distance map with positive values inside the considered regions, and negative values outside.""" + :return: a distance map with positive values inside the considered regions, and negative values outside. + """ n_dims, _ = utils.get_dims(labels.shape) @@ -1010,7 +1238,7 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): # mask label map around specify values if masking_labels is not None: masking_labels = utils.reformat_to_list(masking_labels) - mask = np.zeros(tmp_labels.shape, dtype='bool') + mask = np.zeros(tmp_labels.shape, dtype="bool") for masking_label in masking_labels: mask = mask | tmp_labels == masking_label else: @@ -1020,17 +1248,22 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): # compute distances dist_in = distance_transform_edt(mask) dist_in = np.where(mask, dist_in - 0.5, dist_in) - dist_out = - distance_transform_edt(not_mask) + dist_out = -distance_transform_edt(not_mask) dist_out = np.where(not_mask, dist_out + 0.5, dist_out) tmp_dist = dist_in + dist_out # put back in original matrix if we cropped if crop_idx is not None: - dist = np.min(tmp_dist) * np.ones(labels.shape, dtype='float32') + dist = np.min(tmp_dist) * np.ones(labels.shape, dtype="float32") if n_dims == 3: - dist[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...] = tmp_dist + dist[ + crop_idx[0] : crop_idx[3], + crop_idx[1] : crop_idx[4], + crop_idx[2] : crop_idx[5], + ..., + ] = tmp_dist elif n_dims == 2: - dist[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...] = tmp_dist + dist[crop_idx[0] : crop_idx[2], crop_idx[1] : crop_idx[3], ...] = tmp_dist else: dist = tmp_dist @@ -1039,8 +1272,20 @@ def compute_distance_map(labels, masking_labels=None, crop_margin=None): # ------------------------------------------------- edit volumes in dir ------------------------------------------------ -def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dilate=0, erode=0, fill_holes=False, - masking_value=0, write_mask=False, mask_result_dir=None, recompute=True): + +def mask_images_in_dir( + image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + dilate=0, + erode=0, + fill_holes=False, + masking_value=0, + write_mask=False, + mask_result_dir=None, + recompute=True, +): """Mask all volumes in a folder, either with masks in a specified folder, or by keeping only the intensity values above a specified threshold. :param image_dir: path of directory with images to mask @@ -1072,7 +1317,7 @@ def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dila path_masks = [None] * len(path_images) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'masking', True) + loop_info = utils.LoopInfo(len(path_images), 10, "masking", True) for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): loop_info.update(idx) @@ -1084,22 +1329,41 @@ def mask_images_in_dir(image_dir, result_dir, mask_dir=None, threshold=0.1, dila mask = utils.load_volume(path_mask) else: mask = None - im = mask_volume(im, mask, threshold, dilate, erode, fill_holes, masking_value, write_mask) + im = mask_volume( + im, + mask, + threshold, + dilate, + erode, + fill_holes, + masking_value, + write_mask, + ) # write mask if necessary if write_mask: - assert mask_result_dir is not None, 'if write_mask is True, mask_result_dir has to be specified as well' - mask_result_path = os.path.join(mask_result_dir, os.path.basename(path_image)) + assert ( + mask_result_dir is not None + ), "if write_mask is True, mask_result_dir has to be specified as well" + mask_result_path = os.path.join( + mask_result_dir, os.path.basename(path_image) + ) utils.save_volume(im[1], aff, h, mask_result_path) utils.save_volume(im[0], aff, h, path_result) else: utils.save_volume(im, aff, h, path_result) -def rescale_images_in_dir(image_dir, result_dir, - new_min=0, new_max=255, - min_percentile=2, max_percentile=98, use_positive_only=True, - recompute=True): +def rescale_images_in_dir( + image_dir, + result_dir, + new_min=0, + new_max=255, + min_percentile=2, + max_percentile=98, + use_positive_only=True, + recompute=True, +): """This function linearly rescales all volumes in image_dir between new_min and new_max. :param image_dir: path of directory with images to rescale :param result_dir: path of directory where rescaled images will be writen @@ -1118,18 +1382,22 @@ def rescale_images_in_dir(image_dir, result_dir, # loop over images path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'rescaling', True) + loop_info = utils.LoopInfo(len(path_images), 10, "rescaling", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) path_result = os.path.join(result_dir, os.path.basename(path_image)) if (not os.path.isfile(path_result)) | recompute: im, aff, h = utils.load_volume(path_image, im_only=False) - im = rescale_volume(im, new_min, new_max, min_percentile, max_percentile, use_positive_only) + im = rescale_volume( + im, new_min, new_max, min_percentile, max_percentile, use_positive_only + ) utils.save_volume(im, aff, h, path_result) -def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_shape=None, recompute=True): +def crop_images_in_dir( + image_dir, result_dir, cropping_margin=None, cropping_shape=None, recompute=True +): """Crop all volumes in a folder by a given margin, or to a given shape. :param image_dir: path of directory with images to rescale :param result_dir: path of directory where cropped images will be writen @@ -1145,7 +1413,7 @@ def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_sha # loop over images and masks path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) @@ -1157,13 +1425,15 @@ def crop_images_in_dir(image_dir, result_dir, cropping_margin=None, cropping_sha utils.save_volume(volume, aff, h, path_result) -def crop_images_around_region_in_dir(image_dir, - result_dir, - mask_dir=None, - threshold=0.1, - masking_labels=None, - crop_margin=5, - recompute=True): +def crop_images_around_region_in_dir( + image_dir, + result_dir, + mask_dir=None, + threshold=0.1, + masking_labels=None, + crop_margin=5, + recompute=True, +): """Crop all volumes in a folder around a region, which is defined for each volume by a mask obtained by either 1) directly providing it as input 2) thresholding the input volume @@ -1189,7 +1459,7 @@ def crop_images_around_region_in_dir(image_dir, path_masks = [None] * len(path_images) # loop over images and masks - loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): loop_info.update(idx) @@ -1201,11 +1471,15 @@ def crop_images_around_region_in_dir(image_dir, mask = utils.load_volume(path_mask) else: mask = None - volume, cropping, aff = crop_volume_around_region(volume, mask, threshold, masking_labels, crop_margin, aff) + volume, cropping, aff = crop_volume_around_region( + volume, mask, threshold, masking_labels, crop_margin, aff + ) utils.save_volume(volume, aff, h, path_result) -def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, recompute=True): +def pad_images_in_dir( + image_dir, result_dir, max_shape=None, padding_value=0, recompute=True +): """Pads all the volumes in a folder to the same shape (either provided or computed). :param image_dir: path of directory with images to pad :param result_dir: path of directory where padded images will be writen @@ -1227,11 +1501,13 @@ def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, re max_shape, aff, _, _, h, _ = utils.get_volume_info(path_images[0]) for path_image in path_images[1:]: image_shape, aff, _, _, h, _ = utils.get_volume_info(path_image) - max_shape = tuple(np.maximum(np.asarray(max_shape), np.asarray(image_shape))) + max_shape = tuple( + np.maximum(np.asarray(max_shape), np.asarray(image_shape)) + ) max_shape = np.array(max_shape) # loop over label maps - loop_info = utils.LoopInfo(len(path_images), 10, 'padding', True) + loop_info = utils.LoopInfo(len(path_images), 10, "padding", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) @@ -1245,7 +1521,9 @@ def pad_images_in_dir(image_dir, result_dir, max_shape=None, padding_value=0, re return max_shape -def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompute=True): +def flip_images_in_dir( + image_dir, result_dir, axis=None, direction=None, recompute=True +): """Flip all images in a directory along a specified axis. If unknown, this axis can be replaced by an anatomical direction. :param image_dir: path of directory with images to flip @@ -1260,7 +1538,7 @@ def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompu # loop over images path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'flipping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "flipping", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) @@ -1272,7 +1550,9 @@ def flip_images_in_dir(image_dir, result_dir, axis=None, direction=None, recompu utils.save_volume(im, aff, h, path_result) -def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, recompute=True): +def align_images_in_dir( + image_dir, result_dir, aff_ref=None, path_ref=None, recompute=True +): """This function aligns all images in image_dir to a reference orientation (axes and directions). This reference orientation can be directly provided as an affine matrix, or can be specified by a reference volume. If neither are provided, the reference orientation is assumed to be an identity matrix. @@ -1291,9 +1571,14 @@ def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, reco # read reference affine matrix if path_ref is not None: - assert aff_ref is None, 'cannot provide aff_ref and path_ref together.' + assert aff_ref is None, "cannot provide aff_ref and path_ref together." basename = os.path.basename(path_ref) - if ('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename): + if ( + (".nii.gz" in basename) + | (".nii" in basename) + | (".mgz" in basename) + | (".npz" in basename) + ): _, aff_ref, _ = utils.load_volume(path_ref, im_only=False) path_refs = [None] * len(path_images) else: @@ -1306,7 +1591,7 @@ def align_images_in_dir(image_dir, result_dir, aff_ref=None, path_ref=None, reco path_refs = [None] * len(path_images) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'aligning', True) + loop_info = utils.LoopInfo(len(path_images), 10, "aligning", True) for idx, (path_image, path_ref) in enumerate(zip(path_images, path_refs)): loop_info.update(idx) @@ -1331,7 +1616,7 @@ def correct_nans_images_in_dir(image_dir, result_dir, recompute=True): # loop over images path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'correcting', True) + loop_info = utils.LoopInfo(len(path_images), 10, "correcting", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) @@ -1343,7 +1628,9 @@ def correct_nans_images_in_dir(image_dir, result_dir, recompute=True): utils.save_volume(im, aff, h, path_result) -def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, recompute=True): +def blur_images_in_dir( + image_dir, result_dir, sigma, mask_dir=None, gpu=False, recompute=True +): """This function blurs all the images in image_dir with kernels of the specified std deviations. :param image_dir: path of directory with images to blur :param result_dir: path of directory where blurred images will be writen @@ -1368,17 +1655,21 @@ def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, r # loop over images previous_model_input_shape = None model = None - loop_info = utils.LoopInfo(len(path_images), 10, 'blurring', True) + loop_info = utils.LoopInfo(len(path_images), 10, "blurring", True) for idx, (path_image, path_mask) in enumerate(zip(path_images, path_masks)): loop_info.update(idx) # load image path_result = os.path.join(result_dir, os.path.basename(path_image)) if (not os.path.isfile(path_result)) | recompute: - im, im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_image, return_volume=True) + im, im_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_image, return_volume=True + ) if path_mask is not None: mask = utils.load_volume(path_mask) - assert mask.shape == im.shape, 'mask and image should have the same shape' + assert ( + mask.shape == im.shape + ), "mask and image should have the same shape" else: mask = None @@ -1391,13 +1682,17 @@ def blur_images_in_dir(image_dir, result_dir, sigma, mask_dir=None, gpu=False, r if mask is None: image = GaussianBlur(sigma=sigma)(inputs[0]) else: - inputs.append(KL.Input(shape=im_shape + [1], dtype='float32')) + inputs.append(KL.Input(shape=im_shape + [1], dtype="float32")) image = GaussianBlur(sigma=sigma, use_mask=True)(inputs) model = Model(inputs=inputs, outputs=image) if mask is None: im = np.squeeze(model.predict(utils.add_axis(im, axis=[0, -1]))) else: - im = np.squeeze(model.predict([utils.add_axis(im, [0, -1]), utils.add_axis(mask, [0, -1])])) + im = np.squeeze( + model.predict( + [utils.add_axis(im, [0, -1]), utils.add_axis(mask, [0, -1])] + ) + ) else: im = blur_volume(im, sigma, mask=mask) utils.save_volume(im, aff, h, path_result) @@ -1414,7 +1709,9 @@ def create_mutlimodal_images(list_channel_dir, result_dir, recompute=True): # create result dir utils.mkdir(result_dir) - assert isinstance(list_channel_dir, (list, tuple)), 'list_channel_dir should be a list or a tuple' + assert isinstance( + list_channel_dir, (list, tuple) + ), "list_channel_dir should be a list or a tuple" # gather path of all images for all channels list_channel_paths = [utils.list_images_in_folder(d) for d in list_channel_dir] @@ -1422,27 +1719,33 @@ def create_mutlimodal_images(list_channel_dir, result_dir, recompute=True): n_channels = len(list_channel_dir) for channel_paths in list_channel_paths: if len(channel_paths) != n_images: - raise ValueError('all directories should have the same number of files') + raise ValueError("all directories should have the same number of files") # loop over images - loop_info = utils.LoopInfo(n_images, 10, 'processing', True) + loop_info = utils.LoopInfo(n_images, 10, "processing", True) for idx in range(n_images): loop_info.update(idx) # stack all channels and save multichannel image - path_result = os.path.join(result_dir, os.path.basename(list_channel_paths[0][idx])) + path_result = os.path.join( + result_dir, os.path.basename(list_channel_paths[0][idx]) + ) if (not os.path.isfile(path_result)) | recompute: list_channels = list() tmp_aff = None tmp_h = None for channel_idx in range(n_channels): - tmp_channel, tmp_aff, tmp_h = utils.load_volume(list_channel_paths[channel_idx][idx], im_only=False) + tmp_channel, tmp_aff, tmp_h = utils.load_volume( + list_channel_paths[channel_idx][idx], im_only=False + ) list_channels.append(tmp_channel) im = np.stack(list_channels, axis=-1) utils.save_volume(im, tmp_aff, tmp_h, path_result) -def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir=None, recompute=True): +def convert_images_in_dir_to_nifty( + image_dir, result_dir, aff=None, ref_aff_dir=None, recompute=True +): """Converts all images in image_dir to nifty format. :param image_dir: path of directory with images to convert :param result_dir: path of directory where converted images will be writen @@ -1464,14 +1767,19 @@ def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir= path_ref_images = [None] * len(path_images) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'converting', True) + loop_info = utils.LoopInfo(len(path_images), 10, "converting", True) for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)): loop_info.update(idx) # convert images to nifty format - path_result = os.path.join(result_dir, os.path.basename(utils.strip_extension(path_image))) + '.nii.gz' + path_result = ( + os.path.join( + result_dir, os.path.basename(utils.strip_extension(path_image)) + ) + + ".nii.gz" + ) if (not os.path.isfile(path_result)) | recompute: - if utils.get_image_extension(path_image) == 'nii.gz': + if utils.get_image_extension(path_image) == "nii.gz": shutil.copy2(path_image, path_result) else: im, tmp_aff, h = utils.load_volume(path_image, im_only=False) @@ -1482,15 +1790,17 @@ def convert_images_in_dir_to_nifty(image_dir, result_dir, aff=None, ref_aff_dir= utils.save_volume(im, tmp_aff, h, path_result) -def mri_convert_images_in_dir(image_dir, - result_dir, - interpolation=None, - reference_dir=None, - same_reference=False, - voxsize=None, - path_freesurfer='/usr/local/freesurfer', - mri_convert_path='/usr/local/freesurfer/bin/mri_convert', - recompute=True): +def mri_convert_images_in_dir( + image_dir, + result_dir, + interpolation=None, + reference_dir=None, + same_reference=False, + voxsize=None, + path_freesurfer="/usr/local/freesurfer", + mri_convert_path="/usr/local/freesurfer/bin/mri_convert", + recompute=True, +): """This function launches mri_convert on all images contained in image_dir, and writes the results in result_dir. The interpolation type can be specified (i.e. 'nearest'), as well as a folder containing references for resampling. reference_dir can be the path of a single *image* if same_reference=True. @@ -1512,9 +1822,9 @@ def mri_convert_images_in_dir(image_dir, utils.mkdir(result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - mri_convert = mri_convert_path + ' ' + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = mri_convert_path + " " # list images path_images = utils.list_images_in_folder(image_dir) @@ -1523,36 +1833,42 @@ def mri_convert_images_in_dir(image_dir, path_references = [reference_dir] * len(path_images) else: path_references = utils.list_images_in_folder(reference_dir) - assert len(path_references) == len(path_images), 'different number of files in image_dir and reference_dir' + assert len(path_references) == len( + path_images + ), "different number of files in image_dir and reference_dir" else: path_references = [None] * len(path_images) # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'converting', True) - for idx, (path_image, path_reference) in enumerate(zip(path_images, path_references)): + loop_info = utils.LoopInfo(len(path_images), 10, "converting", True) + for idx, (path_image, path_reference) in enumerate( + zip(path_images, path_references) + ): loop_info.update(idx) # convert image path_result = os.path.join(result_dir, os.path.basename(path_image)) if (not os.path.isfile(path_result)) | recompute: - cmd = mri_convert + path_image + ' ' + path_result + ' -odt float' + cmd = mri_convert + path_image + " " + path_result + " -odt float" if interpolation is not None: - cmd += ' -rt ' + interpolation + cmd += " -rt " + interpolation if reference_dir is not None: - cmd += ' -rl ' + path_reference + cmd += " -rl " + path_reference if voxsize is not None: - voxsize = utils.reformat_to_list(voxsize, dtype='float') - cmd += ' --voxsize ' + ' '.join([str(np.around(v, 3)) for v in voxsize]) + voxsize = utils.reformat_to_list(voxsize, dtype="float") + cmd += " --voxsize " + " ".join([str(np.around(v, 3)) for v in voxsize]) os.system(cmd) -def samseg_images_in_dir(image_dir, - result_dir, - atlas_dir=None, - threads=4, - path_freesurfer='/usr/local/freesurfer', - keep_segm_only=True, - recompute=True): +def samseg_images_in_dir( + image_dir, + result_dir, + atlas_dir=None, + threads=4, + path_freesurfer="/usr/local/freesurfer", + keep_segm_only=True, + recompute=True, +): """This function launches samseg for all images contained in image_dir and writes the results in result_dir. If keep_segm_only=True, the result segmentation is copied in result_dir and SAMSEG's intermediate result dir is deleted. @@ -1570,29 +1886,42 @@ def samseg_images_in_dir(image_dir, utils.mkdir(result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - path_samseg = os.path.join(path_freesurfer, 'bin', 'run_samseg') + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + path_samseg = os.path.join(path_freesurfer, "bin", "run_samseg") # loop over images path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) for idx, path_image in enumerate(path_images): loop_info.update(idx) # build path_result - path_im_result_dir = os.path.join(result_dir, utils.strip_extension(os.path.basename(path_image))) - path_samseg_result = os.path.join(path_im_result_dir, 'seg.mgz') + path_im_result_dir = os.path.join( + result_dir, utils.strip_extension(os.path.basename(path_image)) + ) + path_samseg_result = os.path.join(path_im_result_dir, "seg.mgz") if keep_segm_only: - path_result = os.path.join(result_dir, utils.strip_extension(os.path.basename(path_image)) + '_seg.mgz') + path_result = os.path.join( + result_dir, + utils.strip_extension(os.path.basename(path_image)) + "_seg.mgz", + ) else: path_result = path_samseg_result # run samseg if (not os.path.isfile(path_result)) | recompute: - cmd = utils.mkcmd(path_samseg, '-i', path_image, '-o', path_im_result_dir, '--threads', threads) + cmd = utils.mkcmd( + path_samseg, + "-i", + path_image, + "-o", + path_im_result_dir, + "--threads", + threads, + ) if atlas_dir is not None: - cmd = utils.mkcmd(cmd, '-a', atlas_dir) + cmd = utils.mkcmd(cmd, "-a", atlas_dir) os.system(cmd) # move segmentation to result_dir if necessary @@ -1603,18 +1932,20 @@ def samseg_images_in_dir(image_dir, shutil.rmtree(path_im_result_dir) -def niftyreg_images_in_dir(image_dir, - reference_dir, - nifty_reg_function='reg_resample', - input_transformation_dir=None, - result_dir=None, - result_transformation_dir=None, - interpolation=None, - same_floating=False, - same_reference=False, - same_transformation=False, - path_nifty_reg='/home/benjamin/Softwares/niftyreg-gpu/build/reg-apps', - recompute=True): +def niftyreg_images_in_dir( + image_dir, + reference_dir, + nifty_reg_function="reg_resample", + input_transformation_dir=None, + result_dir=None, + result_transformation_dir=None, + interpolation=None, + same_floating=False, + same_reference=False, + same_transformation=False, + path_nifty_reg="/home/benjamin/Softwares/niftyreg-gpu/build/reg-apps", + recompute=True, +): """This function launches one of niftyreg functions (reg_aladin, reg_f3d, reg_resample) on all images contained in image_dir. :param image_dir: path of directory with images to register. Can also be a single image, in that case set @@ -1650,54 +1981,70 @@ def niftyreg_images_in_dir(image_dir, path_images = utils.list_images_in_folder(image_dir) path_references = utils.list_images_in_folder(reference_dir) if same_reference: - path_references = utils.reformat_to_list(path_references, length=len(path_images)) + path_references = utils.reformat_to_list( + path_references, length=len(path_images) + ) if same_floating: path_images = utils.reformat_to_list(path_images, length=len(path_references)) - assert len(path_references) == len(path_images), 'different number of files in image_dir and reference_dir' + assert len(path_references) == len( + path_images + ), "different number of files in image_dir and reference_dir" # list input transformations if input_transformation_dir is not None: if same_transformation: - path_input_transfs = utils.reformat_to_list(input_transformation_dir, length=len(path_images)) + path_input_transfs = utils.reformat_to_list( + input_transformation_dir, length=len(path_images) + ) else: path_input_transfs = utils.list_files(input_transformation_dir) - assert len(path_input_transfs) == len(path_images), 'different number of transformations and images' + assert len(path_input_transfs) == len( + path_images + ), "different number of transformations and images" else: path_input_transfs = [None] * len(path_images) # define flag input trans if input_transformation_dir is not None: - if nifty_reg_function == 'reg_aladin': - flag_input_trans = '-inaff' - elif nifty_reg_function == 'reg_f3d': - flag_input_trans = '-aff' - elif nifty_reg_function == 'reg_resample': - flag_input_trans = '-trans' + if nifty_reg_function == "reg_aladin": + flag_input_trans = "-inaff" + elif nifty_reg_function == "reg_f3d": + flag_input_trans = "-aff" + elif nifty_reg_function == "reg_resample": + flag_input_trans = "-trans" else: - raise Exception('nifty_reg_function can only be "reg_aladin", "reg_f3d", or "reg_resample"') + raise Exception( + 'nifty_reg_function can only be "reg_aladin", "reg_f3d", or "reg_resample"' + ) else: flag_input_trans = None # define flag result transformation if result_transformation_dir is not None: - if nifty_reg_function == 'reg_aladin': - flag_result_trans = '-aff' - elif nifty_reg_function == 'reg_f3d': - flag_result_trans = '-cpp' + if nifty_reg_function == "reg_aladin": + flag_result_trans = "-aff" + elif nifty_reg_function == "reg_f3d": + flag_result_trans = "-cpp" else: - raise Exception('result_transformation_dir can only be used with "reg_aladin" or "reg_f3d"') + raise Exception( + 'result_transformation_dir can only be used with "reg_aladin" or "reg_f3d"' + ) else: flag_result_trans = None # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) - for idx, (path_image, path_ref, path_input_trans) in enumerate(zip(path_images, - path_references, - path_input_transfs)): + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) + for idx, (path_image, path_ref, path_input_trans) in enumerate( + zip(path_images, path_references, path_input_transfs) + ): loop_info.update(idx) # define path registered image - name = os.path.basename(path_ref) if same_floating else os.path.basename(path_image) + name = ( + os.path.basename(path_ref) + if same_floating + else os.path.basename(path_image) + ) if result_dir is not None: path_result = os.path.join(result_dir, name) result_already_computed = os.path.isfile(path_result) @@ -1707,8 +2054,10 @@ def niftyreg_images_in_dir(image_dir, # define path resulting transformation if result_transformation_dir is not None: - if nifty_reg_function == 'reg_aladin': - path_result_trans = os.path.join(result_transformation_dir, utils.strip_extension(name) + '.txt') + if nifty_reg_function == "reg_aladin": + path_result_trans = os.path.join( + result_transformation_dir, utils.strip_extension(name) + ".txt" + ) result_trans_already_computed = os.path.isfile(path_result_trans) else: path_result_trans = os.path.join(result_transformation_dir, name) @@ -1717,30 +2066,36 @@ def niftyreg_images_in_dir(image_dir, path_result_trans = None result_trans_already_computed = True - if (not result_already_computed) | (not result_trans_already_computed) | recompute: + if ( + (not result_already_computed) + | (not result_trans_already_computed) + | recompute + ): # build main command - cmd = utils.mkcmd(nifty_reg, '-ref', path_ref, '-flo', path_image, '-pad 0') + cmd = utils.mkcmd(nifty_reg, "-ref", path_ref, "-flo", path_image, "-pad 0") # add options if path_result is not None: - cmd = utils.mkcmd(cmd, '-res', path_result) + cmd = utils.mkcmd(cmd, "-res", path_result) if flag_input_trans is not None: cmd = utils.mkcmd(cmd, flag_input_trans, path_input_trans) if flag_result_trans is not None: cmd = utils.mkcmd(cmd, flag_result_trans, path_result_trans) if interpolation is not None: - cmd = utils.mkcmd(cmd, '-inter', interpolation) + cmd = utils.mkcmd(cmd, "-inter", interpolation) # execute os.system(cmd) -def upsample_anisotropic_images(image_dir, - resample_image_result_dir, - resample_like_dir, - path_freesurfer='/usr/local/freesurfer/', - recompute=True): +def upsample_anisotropic_images( + image_dir, + resample_image_result_dir, + resample_like_dir, + path_freesurfer="/usr/local/freesurfer/", + recompute=True, +): """This function takes as input a set of LR images and resample them to HR with respect to reference images. :param image_dir: path of directory with input images (only uni-modal images supported) :param resample_image_result_dir: path of directory where resampled images will be writen @@ -1753,66 +2108,97 @@ def upsample_anisotropic_images(image_dir, utils.mkdir(resample_image_result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") # list images and labels path_images = utils.list_images_in_folder(image_dir) path_ref_images = utils.list_images_in_folder(resample_like_dir) - assert len(path_images) == len(path_ref_images), \ - 'the folders containing the images and their references are not the same size' + assert len(path_images) == len( + path_ref_images + ), "the folders containing the images and their references are not the same size" # loop over images - loop_info = utils.LoopInfo(len(path_images), 10, 'upsampling', True) + loop_info = utils.LoopInfo(len(path_images), 10, "upsampling", True) for idx, (path_image, path_ref) in enumerate(zip(path_images, path_ref_images)): loop_info.update(idx) # upsample image - _, _, n_dims, _, _, image_res = utils.get_volume_info(path_image, return_volume=False) - path_im_upsampled = os.path.join(resample_image_result_dir, os.path.basename(path_image)) + _, _, n_dims, _, _, image_res = utils.get_volume_info( + path_image, return_volume=False + ) + path_im_upsampled = os.path.join( + resample_image_result_dir, os.path.basename(path_image) + ) if (not os.path.isfile(path_im_upsampled)) | recompute: - cmd = utils.mkcmd(mri_convert, path_image, path_im_upsampled, '-rl', path_ref, '-odt float') + cmd = utils.mkcmd( + mri_convert, + path_image, + path_im_upsampled, + "-rl", + path_ref, + "-odt float", + ) os.system(cmd) - path_dist_map = os.path.join(resample_image_result_dir, 'dist_map_' + os.path.basename(path_image)) + path_dist_map = os.path.join( + resample_image_result_dir, "dist_map_" + os.path.basename(path_image) + ) if (not os.path.isfile(path_dist_map)) | recompute: im, aff, h = utils.load_volume(path_image, im_only=False) - dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing='ij') - tmp_dir = utils.strip_extension(path_im_upsampled) + '_meshes' + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing="ij") + tmp_dir = utils.strip_extension(path_im_upsampled) + "_meshes" utils.mkdir(tmp_dir) path_meshes_up = list() - for (i, maps) in enumerate(dist_map): - path_mesh = os.path.join(tmp_dir, '%s_' % i + os.path.basename(path_image)) - path_mesh_up = os.path.join(tmp_dir, 'up_%s_' % i + os.path.basename(path_image)) + for i, maps in enumerate(dist_map): + path_mesh = os.path.join( + tmp_dir, "%s_" % i + os.path.basename(path_image) + ) + path_mesh_up = os.path.join( + tmp_dir, "up_%s_" % i + os.path.basename(path_image) + ) utils.save_volume(maps, aff, h, path_mesh) - cmd = utils.mkcmd(mri_convert, path_mesh, path_mesh_up, '-rl', path_im_upsampled, '-odt float') + cmd = utils.mkcmd( + mri_convert, + path_mesh, + path_mesh_up, + "-rl", + path_im_upsampled, + "-odt float", + ) os.system(cmd) path_meshes_up.append(path_mesh_up) mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) - mesh_up = np.stack([mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1) + mesh_up = np.stack( + [mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1 + ) shutil.rmtree(tmp_dir) floor = np.floor(mesh_up) ceil = np.ceil(mesh_up) f_dist = mesh_up - floor c_dist = ceil - mesh_up - dist = np.minimum(f_dist, c_dist) * utils.add_axis(image_res, axis=[0] * n_dims) - dist = np.sqrt(np.sum(dist ** 2, axis=-1)) + dist = np.minimum(f_dist, c_dist) * utils.add_axis( + image_res, axis=[0] * n_dims + ) + dist = np.sqrt(np.sum(dist**2, axis=-1)) utils.save_volume(dist, aff, h, path_dist_map) -def simulate_upsampled_anisotropic_images(image_dir, - downsample_image_result_dir, - resample_image_result_dir, - data_res, - labels_dir=None, - downsample_labels_result_dir=None, - slice_thickness=None, - build_dist_map=False, - path_freesurfer='/usr/local/freesurfer/', - gpu=True, - recompute=True): +def simulate_upsampled_anisotropic_images( + image_dir, + downsample_image_result_dir, + resample_image_result_dir, + data_res, + labels_dir=None, + downsample_labels_result_dir=None, + slice_thickness=None, + build_dist_map=False, + path_freesurfer="/usr/local/freesurfer/", + gpu=True, + recompute=True, +): """This function takes as input a set of HR images and creates two datasets with it: 1) a set of LR images obtained by downsampling the HR images with nearest neighbour interpolation, 2) a set of HR images obtained by resampling the LR images to native HR with linear interpolation. @@ -1835,39 +2221,58 @@ def simulate_upsampled_anisotropic_images(image_dir, utils.mkdir(resample_image_result_dir) utils.mkdir(downsample_image_result_dir) if labels_dir is not None: - assert downsample_labels_result_dir is not None, \ - 'downsample_labels_result_dir should not be None if labels_dir is specified' + assert ( + downsample_labels_result_dir is not None + ), "downsample_labels_result_dir should not be None if labels_dir is specified" utils.mkdir(downsample_labels_result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") # list images and labels path_images = utils.list_images_in_folder(image_dir) - path_labels = [None] * len(path_images) if labels_dir is None else utils.list_images_in_folder(labels_dir) + path_labels = ( + [None] * len(path_images) + if labels_dir is None + else utils.list_images_in_folder(labels_dir) + ) # initialisation - _, _, n_dims, _, _, image_res = utils.get_volume_info(path_images[0], return_volume=False, aff_ref=np.eye(4)) - data_res = np.squeeze(utils.reformat_to_n_channels_array(data_res, n_dims, n_channels=1)) + _, _, n_dims, _, _, image_res = utils.get_volume_info( + path_images[0], return_volume=False, aff_ref=np.eye(4) + ) + data_res = np.squeeze( + utils.reformat_to_n_channels_array(data_res, n_dims, n_channels=1) + ) slice_thickness = utils.reformat_to_list(slice_thickness, length=n_dims) # loop over images previous_model_input_shape = None model = None - loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) for idx, (path_image, path_labels) in enumerate(zip(path_images, path_labels)): loop_info.update(idx) # downsample image - path_im_downsampled = os.path.join(downsample_image_result_dir, os.path.basename(path_image)) + path_im_downsampled = os.path.join( + downsample_image_result_dir, os.path.basename(path_image) + ) if (not os.path.isfile(path_im_downsampled)) | recompute: - im, _, aff, n_dims, _, h, image_res = utils.get_volume_info(path_image, return_volume=True) - im, aff_aligned = align_volume_to_ref(im, aff, aff_ref=np.eye(4), return_aff=True, n_dims=n_dims) + im, _, aff, n_dims, _, h, image_res = utils.get_volume_info( + path_image, return_volume=True + ) + im, aff_aligned = align_volume_to_ref( + im, aff, aff_ref=np.eye(4), return_aff=True, n_dims=n_dims + ) im_shape = list(im.shape[:n_dims]) - sigma = blurring_sigma_for_downsampling(image_res, data_res, thickness=slice_thickness) - sigma = [0 if data_res[i] == image_res[i] else sigma[i] for i in range(n_dims)] + sigma = blurring_sigma_for_downsampling( + image_res, data_res, thickness=slice_thickness + ) + sigma = [ + 0 if data_res[i] == image_res[i] else sigma[i] for i in range(n_dims) + ] # blur image if gpu: @@ -1882,57 +2287,100 @@ def simulate_upsampled_anisotropic_images(image_dir, utils.save_volume(im, aff_aligned, h, path_im_downsampled) # downsample blurred image - voxsize = ' '.join([str(r) for r in data_res]) - cmd = utils.mkcmd(mri_convert, path_im_downsampled, path_im_downsampled, '--voxsize', voxsize, - '-odt float -rt nearest') + voxsize = " ".join([str(r) for r in data_res]) + cmd = utils.mkcmd( + mri_convert, + path_im_downsampled, + path_im_downsampled, + "--voxsize", + voxsize, + "-odt float -rt nearest", + ) os.system(cmd) # downsample labels if necessary if path_labels is not None: - path_lab_downsampled = os.path.join(downsample_labels_result_dir, os.path.basename(path_labels)) + path_lab_downsampled = os.path.join( + downsample_labels_result_dir, os.path.basename(path_labels) + ) if (not os.path.isfile(path_lab_downsampled)) | recompute: - cmd = utils.mkcmd(mri_convert, path_labels, path_lab_downsampled, '-rl', path_im_downsampled, - '-odt float -rt nearest') + cmd = utils.mkcmd( + mri_convert, + path_labels, + path_lab_downsampled, + "-rl", + path_im_downsampled, + "-odt float -rt nearest", + ) os.system(cmd) # upsample image - path_im_upsampled = os.path.join(resample_image_result_dir, os.path.basename(path_image)) + path_im_upsampled = os.path.join( + resample_image_result_dir, os.path.basename(path_image) + ) if (not os.path.isfile(path_im_upsampled)) | recompute: - cmd = utils.mkcmd(mri_convert, path_im_downsampled, path_im_upsampled, '-rl', path_image, '-odt float') + cmd = utils.mkcmd( + mri_convert, + path_im_downsampled, + path_im_upsampled, + "-rl", + path_image, + "-odt float", + ) os.system(cmd) if build_dist_map: - path_dist_map = os.path.join(resample_image_result_dir, 'dist_map_' + os.path.basename(path_image)) + path_dist_map = os.path.join( + resample_image_result_dir, "dist_map_" + os.path.basename(path_image) + ) if (not os.path.isfile(path_dist_map)) | recompute: im, aff, h = utils.load_volume(path_im_downsampled, im_only=False) - dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing='ij') - tmp_dir = utils.strip_extension(path_im_downsampled) + '_meshes' + dist_map = np.meshgrid(*[np.arange(s) for s in im.shape], indexing="ij") + tmp_dir = utils.strip_extension(path_im_downsampled) + "_meshes" utils.mkdir(tmp_dir) path_meshes_up = list() - for (i, d_map) in enumerate(dist_map): - path_mesh = os.path.join(tmp_dir, '%s_' % i + os.path.basename(path_image)) - path_mesh_up = os.path.join(tmp_dir, 'up_%s_' % i + os.path.basename(path_image)) + for i, d_map in enumerate(dist_map): + path_mesh = os.path.join( + tmp_dir, "%s_" % i + os.path.basename(path_image) + ) + path_mesh_up = os.path.join( + tmp_dir, "up_%s_" % i + os.path.basename(path_image) + ) utils.save_volume(d_map, aff, h, path_mesh) - cmd = utils.mkcmd(mri_convert, path_mesh, path_mesh_up, '-rl', path_image, '-odt float') + cmd = utils.mkcmd( + mri_convert, + path_mesh, + path_mesh_up, + "-rl", + path_image, + "-odt float", + ) os.system(cmd) path_meshes_up.append(path_mesh_up) mesh_up_0, aff, h = utils.load_volume(path_meshes_up[0], im_only=False) - mesh_up = np.stack([mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1) + mesh_up = np.stack( + [mesh_up_0] + [utils.load_volume(p) for p in path_meshes_up[1:]], -1 + ) shutil.rmtree(tmp_dir) floor = np.floor(mesh_up) ceil = np.ceil(mesh_up) f_dist = mesh_up - floor c_dist = ceil - mesh_up - dist = np.minimum(f_dist, c_dist) * utils.add_axis(data_res, axis=[0] * n_dims) - dist = np.sqrt(np.sum(dist ** 2, axis=-1)) + dist = np.minimum(f_dist, c_dist) * utils.add_axis( + data_res, axis=[0] * n_dims + ) + dist = np.sqrt(np.sum(dist**2, axis=-1)) utils.save_volume(dist, aff, h, path_dist_map) -def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_channels=10, verbose=True): +def check_images_in_dir( + image_dir, check_values=False, keep_unique=True, max_channels=10, verbose=True +): """Check if all volumes within the same folder share the same characteristics: shape, affine matrix, resolution. Also have option to check if all volumes have the same intensity values (useful for label maps). - :return four lists, each containing the different values detected for a specific parameter among those to check.""" + :return four lists, each containing the different values detected for a specific parameter among those to check. + """ # define information to check list_shape = list() @@ -1946,13 +2394,17 @@ def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_cha # loop through files path_images = utils.list_images_in_folder(image_dir) - loop_info = utils.LoopInfo(len(path_images), 10, 'checking', verbose) if verbose else None + loop_info = ( + utils.LoopInfo(len(path_images), 10, "checking", verbose) if verbose else None + ) for idx, path_image in enumerate(path_images): if loop_info is not None: loop_info.update(idx) # get info - im, shape, aff, n_dims, _, h, res = utils.get_volume_info(path_image, True, np.eye(4), max_channels) + im, shape, aff, n_dims, _, h, res = utils.get_volume_info( + path_image, True, np.eye(4), max_channels + ) axes = get_ras_axes(aff, n_dims=n_dims).tolist() aff[:, np.arange(n_dims)] = aff[:, axes] aff = (np.int32(np.round(np.array(aff[:3, :3]), 2) * 100) / 100).tolist() @@ -1977,8 +2429,17 @@ def check_images_in_dir(image_dir, check_values=False, keep_unique=True, max_cha # ----------------------------------------------- edit label maps in dir ----------------------------------------------- -def correct_labels_in_dir(labels_dir, results_dir, incorrect_labels, correct_labels=None, - use_nearest_label=False, remove_zero=False, smooth=False, recompute=True): + +def correct_labels_in_dir( + labels_dir, + results_dir, + incorrect_labels, + correct_labels=None, + use_nearest_label=False, + remove_zero=False, + smooth=False, + recompute=True, +): """This function corrects label values for all label maps in a folder with either - a list a given values, - or with the nearest label value. @@ -2002,19 +2463,33 @@ def correct_labels_in_dir(labels_dir, results_dir, incorrect_labels, correct_lab # prepare data files path_labels = utils.list_images_in_folder(labels_dir) - loop_info = utils.LoopInfo(len(path_labels), 10, 'correcting', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "correcting", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # correct labels path_result = os.path.join(results_dir, os.path.basename(path_label)) if (not os.path.isfile(path_result)) | recompute: - im, aff, h = utils.load_volume(path_label, im_only=False, dtype='int32') - im = correct_label_map(im, incorrect_labels, correct_labels, use_nearest_label, remove_zero, smooth) + im, aff, h = utils.load_volume(path_label, im_only=False, dtype="int32") + im = correct_label_map( + im, + incorrect_labels, + correct_labels, + use_nearest_label, + remove_zero, + smooth, + ) utils.save_volume(im, aff, h, path_result) -def mask_labels_in_dir(labels_dir, result_dir, values_to_keep, masking_value=0, mask_result_dir=None, recompute=True): +def mask_labels_in_dir( + labels_dir, + result_dir, + values_to_keep, + masking_value=0, + mask_result_dir=None, + recompute=True, +): """This function masks all label maps in a folder by keeping a set of given label values. :param labels_dir: path of directory with input label maps :param result_dir: path of directory where corrected label maps will be writen @@ -2034,30 +2509,42 @@ def mask_labels_in_dir(labels_dir, result_dir, values_to_keep, masking_value=0, # loop over labels path_labels = utils.list_images_in_folder(labels_dir) - loop_info = utils.LoopInfo(len(path_labels), 10, 'masking', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "masking", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # mask labels path_result = os.path.join(result_dir, os.path.basename(path_label)) if mask_result_dir is not None: - path_result_mask = os.path.join(mask_result_dir, os.path.basename(path_label)) + path_result_mask = os.path.join( + mask_result_dir, os.path.basename(path_label) + ) else: - path_result_mask = '' - if (not os.path.isfile(path_result)) | \ - (mask_result_dir is not None) & (not os.path.isfile(path_result_mask)) | \ - recompute: + path_result_mask = "" + if ( + (not os.path.isfile(path_result)) + | (mask_result_dir is not None) & (not os.path.isfile(path_result_mask)) + | recompute + ): lab, aff, h = utils.load_volume(path_label, im_only=False) if mask_result_dir is not None: - labels, mask = mask_label_map(lab, values_to_keep, masking_value, return_mask=True) - path_result_mask = os.path.join(mask_result_dir, os.path.basename(path_label)) + labels, mask = mask_label_map( + lab, values_to_keep, masking_value, return_mask=True + ) + path_result_mask = os.path.join( + mask_result_dir, os.path.basename(path_label) + ) utils.save_volume(mask, aff, h, path_result_mask) else: - labels = mask_label_map(lab, values_to_keep, masking_value, return_mask=False) + labels = mask_label_map( + lab, values_to_keep, masking_value, return_mask=False + ) utils.save_volume(labels, aff, h, path_result) -def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, connectivity=1, recompute=True): +def smooth_labels_in_dir( + labels_dir, result_dir, gpu=False, labels_list=None, connectivity=1, recompute=True +): """Smooth all label maps in a folder by replacing each voxel by the value of its most numerous neighbours. :param labels_dir: path of directory with input label maps :param result_dir: path of directory where smoothed label maps will be writen @@ -2083,28 +2570,40 @@ def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, co smoothing_model = None # loop over label maps - loop_info = utils.LoopInfo(len(path_labels), 10, 'smoothing', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "smoothing", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # smooth label map path_result = os.path.join(result_dir, os.path.basename(path_label)) if (not os.path.isfile(path_result)) | recompute: - labels, label_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_label, return_volume=True) + labels, label_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_label, return_volume=True + ) if label_shape != previous_model_input_shape: previous_model_input_shape = label_shape - smoothing_model = smoothing_gpu_model(label_shape, labels_list, connectivity) - unique_labels = np.unique(labels).astype('int32') + smoothing_model = smoothing_gpu_model( + label_shape, labels_list, connectivity + ) + unique_labels = np.unique(labels).astype("int32") if labels_list is None: smoothed_labels = smoothing_model.predict(utils.add_axis(labels)) else: - labels_to_keep = [lab for lab in unique_labels if lab not in labels_list] - new_labels, mask_new_labels = mask_label_map(labels, labels_to_keep, return_mask=True) - smoothed_labels = np.squeeze(smoothing_model.predict(utils.add_axis(labels))) - smoothed_labels = np.where(mask_new_labels, new_labels, smoothed_labels) + labels_to_keep = [ + lab for lab in unique_labels if lab not in labels_list + ] + new_labels, mask_new_labels = mask_label_map( + labels, labels_to_keep, return_mask=True + ) + smoothed_labels = np.squeeze( + smoothing_model.predict(utils.add_axis(labels)) + ) + smoothed_labels = np.where( + mask_new_labels, new_labels, smoothed_labels + ) mask_new_zeros = (labels > 0) & (smoothed_labels == 0) smoothed_labels[mask_new_zeros] = labels[mask_new_zeros] - utils.save_volume(smoothed_labels, aff, h, path_result, dtype='int32') + utils.save_volume(smoothed_labels, aff, h, path_result, dtype="int32") else: # build kernel @@ -2112,7 +2611,7 @@ def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, co kernel = utils.build_binary_structure(connectivity, n_dims, shape=n_dims) # loop over label maps - loop_info = utils.LoopInfo(len(path_labels), 10, 'smoothing', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "smoothing", True) for idx, path in enumerate(path_labels): loop_info.update(idx) @@ -2121,7 +2620,7 @@ def smooth_labels_in_dir(labels_dir, result_dir, gpu=False, labels_list=None, co if (not os.path.isfile(path_result)) | recompute: volume, aff, h = utils.load_volume(path, im_only=False) new_volume = smooth_label_map(volume, kernel, labels_list) - utils.save_volume(new_volume, aff, h, path_result, dtype='int32') + utils.save_volume(new_volume, aff, h, path_result, dtype="int32") def smoothing_gpu_model(label_shape, label_list, connectivity=1): @@ -2135,18 +2634,26 @@ def smoothing_gpu_model(label_shape, label_list, connectivity=1): # convert labels so values are in [0, ..., N-1] and use one hot encoding n_labels = label_list.shape[0] - labels_in = KL.Input(shape=label_shape, name='lab_input', dtype='int32') + labels_in = KL.Input(shape=label_shape, name="lab_input", dtype="int32") labels = ConvertLabels(label_list)(labels_in) - labels = KL.Lambda(lambda x: tf.one_hot(tf.cast(x, dtype='int32'), depth=n_labels, axis=-1))(labels) + labels = KL.Lambda( + lambda x: tf.one_hot(tf.cast(x, dtype="int32"), depth=n_labels, axis=-1) + )(labels) # count neighbouring voxels n_dims, _ = utils.get_dims(label_shape) - k = utils.add_axis(utils.build_binary_structure(connectivity, n_dims, shape=n_dims), axis=[-1, -1]) - kernel = KL.Lambda(lambda x: tf.convert_to_tensor(k, dtype='float32'))([]) + k = utils.add_axis( + utils.build_binary_structure(connectivity, n_dims, shape=n_dims), axis=[-1, -1] + ) + kernel = KL.Lambda(lambda x: tf.convert_to_tensor(k, dtype="float32"))([]) split = KL.Lambda(lambda x: tf.split(x, [1] * n_labels, axis=-1))(labels) - labels = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding='SAME'))([split[0], kernel]) + labels = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding="SAME"))( + [split[0], kernel] + ) for i in range(1, n_labels): - tmp = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding='SAME'))([split[i], kernel]) + tmp = KL.Lambda(lambda x: tf.nn.convolution(x[0], x[1], padding="SAME"))( + [split[i], kernel] + ) labels = KL.Lambda(lambda x: tf.concat([x[0], x[1]], -1))([labels, tmp]) # take the argmax and convert labels to original values @@ -2155,7 +2662,14 @@ def smoothing_gpu_model(label_shape, label_list, connectivity=1): return Model(inputs=labels_in, outputs=labels) -def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors=1., gpu=False, recompute=True): +def erode_labels_in_dir( + labels_dir, + result_dir, + labels_to_erode, + erosion_factors=1.0, + gpu=False, + recompute=True, +): """Erode a given set of label values for all label maps in a folder. :param labels_dir: path of directory with input label maps :param result_dir: path of directory where cropped label maps will be writen @@ -2173,7 +2687,7 @@ def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors # loop over label maps model = None path_labels = utils.list_images_in_folder(labels_dir) - loop_info = utils.LoopInfo(len(path_labels), 5, 'eroding', True) + loop_info = utils.LoopInfo(len(path_labels), 5, "eroding", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) @@ -2181,16 +2695,20 @@ def erode_labels_in_dir(labels_dir, result_dir, labels_to_erode, erosion_factors labels, aff, h = utils.load_volume(path_label, im_only=False) path_result = os.path.join(result_dir, os.path.basename(path_label)) if (not os.path.isfile(path_result)) | recompute: - labels, model = erode_label_map(labels, labels_to_erode, erosion_factors, gpu, model, return_model=True) + labels, model = erode_label_map( + labels, labels_to_erode, erosion_factors, gpu, model, return_model=True + ) utils.save_volume(labels, aff, h, path_result) -def upsample_labels_in_dir(labels_dir, - target_res, - result_dir, - path_label_list=None, - path_freesurfer='/usr/local/freesurfer/', - recompute=True): +def upsample_labels_in_dir( + labels_dir, + target_res, + result_dir, + path_label_list=None, + path_freesurfer="/usr/local/freesurfer/", + recompute=True, +): """This function upsamples all label maps within a folder. Importantly, each label map is converted into probability maps for all label values, and all these maps are upsampled separately. The upsampled label maps are recovered by taking the argmax of the label values probability maps. @@ -2207,25 +2725,29 @@ def upsample_labels_in_dir(labels_dir, utils.mkdir(result_dir) # set up FreeSurfer - os.environ['FREESURFER_HOME'] = path_freesurfer - os.system(os.path.join(path_freesurfer, 'SetUpFreeSurfer.sh')) - mri_convert = os.path.join(path_freesurfer, 'bin/mri_convert') + os.environ["FREESURFER_HOME"] = path_freesurfer + os.system(os.path.join(path_freesurfer, "SetUpFreeSurfer.sh")) + mri_convert = os.path.join(path_freesurfer, "bin/mri_convert") # list label maps path_labels = utils.list_images_in_folder(labels_dir) - labels_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_labels[0], max_channels=3) + labels_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_labels[0], max_channels=3 + ) # build command target_res = utils.reformat_to_list(target_res, length=n_dims) - post_cmd = '-voxsize ' + ' '.join([str(r) for r in target_res]) + ' -odt float' + post_cmd = "-voxsize " + " ".join([str(r) for r in target_res]) + " -odt float" # load label list and corresponding LUT to make sure that labels go from 0 to N-1 - label_list, _ = utils.get_list_labels(path_label_list, labels_dir=path_labels, FS_sort=False) - new_label_list = np.arange(len(label_list), dtype='int32') + label_list, _ = utils.get_list_labels( + path_label_list, labels_dir=path_labels, FS_sort=False + ) + new_label_list = np.arange(len(label_list), dtype="int32") lut = utils.get_mapping_lut(label_list) # loop over label maps - loop_info = utils.LoopInfo(len(path_labels), 5, 'upsampling', True) + loop_info = utils.LoopInfo(len(path_labels), 5, "upsampling", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) path_result = os.path.join(result_dir, os.path.basename(path_label)) @@ -2233,44 +2755,56 @@ def upsample_labels_in_dir(labels_dir, # load volume labels, aff, h = utils.load_volume(path_label, im_only=False) - labels = lut[labels.astype('int')] + labels = lut[labels.astype("int")] # create individual folders for label map basefilename = utils.strip_extension(os.path.basename(path_label)) indiv_label_dir = os.path.join(result_dir, basefilename) - upsample_indiv_label_dir = os.path.join(result_dir, basefilename + '_upsampled') + upsample_indiv_label_dir = os.path.join( + result_dir, basefilename + "_upsampled" + ) utils.mkdir(indiv_label_dir) utils.mkdir(upsample_indiv_label_dir) # loop over label values for label in new_label_list: - path_mask = os.path.join(indiv_label_dir, str(label) + '.nii.gz') - path_mask_upsampled = os.path.join(upsample_indiv_label_dir, str(label) + '.nii.gz') + path_mask = os.path.join(indiv_label_dir, str(label) + ".nii.gz") + path_mask_upsampled = os.path.join( + upsample_indiv_label_dir, str(label) + ".nii.gz" + ) if not os.path.isfile(path_mask): mask = (labels == label) * 1.0 utils.save_volume(mask, aff, h, path_mask) if not os.path.isfile(path_mask_upsampled): - cmd = utils.mkcmd(mri_convert, path_mask, path_mask_upsampled, post_cmd) + cmd = utils.mkcmd( + mri_convert, path_mask, path_mask_upsampled, post_cmd + ) os.system(cmd) # compute argmax of upsampled probability maps (upload them one at a time) - probmax, aff, h = utils.load_volume(os.path.join(upsample_indiv_label_dir, '0.nii.gz'), im_only=False) - labels = np.zeros(probmax.shape, dtype='int') + probmax, aff, h = utils.load_volume( + os.path.join(upsample_indiv_label_dir, "0.nii.gz"), im_only=False + ) + labels = np.zeros(probmax.shape, dtype="int") for label in new_label_list: - prob = utils.load_volume(os.path.join(upsample_indiv_label_dir, str(label) + '.nii.gz')) + prob = utils.load_volume( + os.path.join(upsample_indiv_label_dir, str(label) + ".nii.gz") + ) idx = prob > probmax labels[idx] = label probmax[idx] = prob[idx] - utils.save_volume(label_list[labels], aff, h, path_result, dtype='int32') - - -def compute_hard_volumes_in_dir(labels_dir, - voxel_volume=None, - path_label_list=None, - skip_background=True, - path_numpy_result=None, - path_csv_result=None, - FS_sort=False): + utils.save_volume(label_list[labels], aff, h, path_result, dtype="int32") + + +def compute_hard_volumes_in_dir( + labels_dir, + voxel_volume=None, + path_label_list=None, + skip_background=True, + path_numpy_result=None, + path_csv_result=None, + FS_sort=False, +): """Compute hard volumes of structures for all label maps in a folder. :param labels_dir: path of directory with input label maps :param voxel_volume: (optional) volume of the voxels. If None, it will be directly inferred from the file header. @@ -2299,10 +2833,10 @@ def compute_hard_volumes_in_dir(labels_dir, # create csv volume file if necessary if path_csv_result is not None: if skip_background: - cvs_header = [['subject'] + [str(lab) for lab in label_list[1:]]] + cvs_header = [["subject"] + [str(lab) for lab in label_list[1:]]] else: - cvs_header = [['subject'] + [str(lab) for lab in label_list]] - with open(path_csv_result, 'w') as csvFile: + cvs_header = [["subject"] + [str(lab) for lab in label_list]] + with open(path_csv_result, "w") as csvFile: writer = csv.writer(csvFile) writer.writerows(cvs_header) csvFile.close() @@ -2313,22 +2847,28 @@ def compute_hard_volumes_in_dir(labels_dir, volumes = np.zeros((label_list.shape[0] - 1, len(path_labels))) else: volumes = np.zeros((label_list.shape[0], len(path_labels))) - loop_info = utils.LoopInfo(len(path_labels), 10, 'processing', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "processing", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # load segmentation, and compute unique labels - labels, _, _, _, _, _, subject_res = utils.get_volume_info(path_label, return_volume=True) + labels, _, _, _, _, _, subject_res = utils.get_volume_info( + path_label, return_volume=True + ) if voxel_volume is None: voxel_volume = float(np.prod(subject_res)) - subject_volumes = compute_hard_volumes(labels, voxel_volume, label_list, skip_background) + subject_volumes = compute_hard_volumes( + labels, voxel_volume, label_list, skip_background + ) volumes[:, idx] = subject_volumes # write volumes if path_csv_result is not None: subject_volumes = np.around(volumes[:, idx], 3) - row = [utils.strip_suffix(os.path.basename(path_label))] + [str(vol) for vol in subject_volumes] - with open(path_csv_result, 'a') as csvFile: + row = [utils.strip_suffix(os.path.basename(path_label))] + [ + str(vol) for vol in subject_volumes + ] + with open(path_csv_result, "a") as csvFile: writer = csv.writer(csvFile) writer.writerow(row) csvFile.close() @@ -2340,12 +2880,14 @@ def compute_hard_volumes_in_dir(labels_dir, return volumes -def build_atlas(labels_dir, - label_list, - align_centre_of_mass=False, - margin=15, - shape=None, - path_atlas=None): +def build_atlas( + labels_dir, + label_list, + align_centre_of_mass=False, + margin=15, + shape=None, + path_atlas=None, +): """This function builds a binary atlas (defined by label values > 0) from several label maps. :param labels_dir: path of directory with input label maps :param label_list: list of all labels in the label maps. If there is more than 1 value here, the different channels @@ -2364,28 +2906,36 @@ def build_atlas(labels_dir, utils.mkdir(os.path.dirname(path_atlas)) # read list labels and create lut - label_list = np.array(utils.reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + label_list = np.array( + utils.reformat_to_list(label_list, load_as_numpy=True, dtype="int") + ) lut = utils.get_mapping_lut(label_list) n_labels = len(label_list) # create empty atlas - im_shape, aff, n_dims, _, h, _ = utils.get_volume_info(path_labels[0], aff_ref=np.eye(4)) + im_shape, aff, n_dims, _, h, _ = utils.get_volume_info( + path_labels[0], aff_ref=np.eye(4) + ) if align_centre_of_mass: shape = [margin * 2] * n_dims else: - shape = utils.reformat_to_list(shape, length=n_dims) if shape is not None else im_shape + shape = ( + utils.reformat_to_list(shape, length=n_dims) + if shape is not None + else im_shape + ) shape = shape + [n_labels] if n_labels > 1 else shape atlas = np.zeros(shape) # loop over label maps - loop_info = utils.LoopInfo(n_label_maps, 10, 'processing', True) + loop_info = utils.LoopInfo(n_label_maps, 10, "processing", True) for idx, path_label in enumerate(path_labels): loop_info.update(idx) # load label map and build mask - lab = utils.load_volume(path_label, dtype='int32', aff_ref=np.eye(4)) + lab = utils.load_volume(path_label, dtype="int32", aff_ref=np.eye(4)) lab = correct_label_map(lab, [31, 63, 72], [4, 43, 0]) - lab = lut[lab.astype('int')] + lab = lut[lab.astype("int")] lab = pad_volume(lab, shape[:n_dims]) lab = crop_volume(lab, cropping_shape=shape[:n_dims]) indices = np.where(lab > 0) @@ -2395,10 +2945,18 @@ def build_atlas(labels_dir, # crop label map around centre of mass if align_centre_of_mass: - centre_of_mass = np.array([np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])], dtype='int32') + centre_of_mass = np.array( + [np.mean(indices[0]), np.mean(indices[1]), np.mean(indices[2])], + dtype="int32", + ) min_crop = centre_of_mass - margin max_crop = centre_of_mass + margin - atlas += lab[min_crop[0]:max_crop[0], min_crop[1]:max_crop[1], min_crop[2]:max_crop[2], ...] + atlas += lab[ + min_crop[0] : max_crop[0], + min_crop[1] : max_crop[1], + min_crop[2] : max_crop[2], + ..., + ] # otherwise just add the one-hot labels else: atlas += lab @@ -2414,6 +2972,7 @@ def build_atlas(labels_dir, # ---------------------------------------------------- edit dataset ---------------------------------------------------- + def check_images_and_labels(image_dir, labels_dir, verbose=True): """Check if corresponding images and labels have the same affine matrices and shapes. Labels are matched to images by sorting order. @@ -2425,10 +2984,14 @@ def check_images_and_labels(image_dir, labels_dir, verbose=True): # list images and labels path_images = utils.list_images_in_folder(image_dir) path_labels = utils.list_images_in_folder(labels_dir) - assert len(path_images) == len(path_labels), 'different number of files in image_dir and labels_dir' + assert len(path_images) == len( + path_labels + ), "different number of files in image_dir and labels_dir" # loop over images and labels - loop_info = utils.LoopInfo(len(path_images), 10, 'checking', verbose) if verbose else None + loop_info = ( + utils.LoopInfo(len(path_images), 10, "checking", verbose) if verbose else None + ) for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): if loop_info is not None: loop_info.update(idx) @@ -2441,20 +3004,22 @@ def check_images_and_labels(image_dir, labels_dir, verbose=True): # check matching affine and shape if aff_lab_list != aff_im_list: - print('aff mismatch :\n' + path_image) + print("aff mismatch :\n" + path_image) print(aff_im_list) print(path_label) print(aff_lab_list) - print('') + print("") if lab.shape != im.shape: - print('shape mismatch :\n' + path_image) + print("shape mismatch :\n" + path_image) print(im.shape) - print('\n' + path_label) + print("\n" + path_label) print(lab.shape) - print('') + print("") -def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_result_dir=None, margin=5): +def crop_dataset_to_minimum_size( + labels_dir, result_dir, image_dir=None, image_result_dir=None, margin=5 +): """Crop all label maps in a directory to the minimum possible common size, with a margin. This is achieved by cropping each label map individually to the minimum size, and by padding all the cropped maps to the same size (taken to be the maximum size of the cropped maps). @@ -2469,7 +3034,9 @@ def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_r # create result dir utils.mkdir(result_dir) if image_dir is not None: - assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" utils.mkdir(image_result_dir) # list labels and images @@ -2481,27 +3048,36 @@ def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_r _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) # loop over label maps for cropping - print('\ncropping labels to individual minimum size') + print("\ncropping labels to individual minimum size") maximum_size = np.zeros(n_dims) - loop_info = utils.LoopInfo(len(path_labels), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_labels), 10, "cropping", True) for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): loop_info.update(idx) # crop label maps and update maximum size of cropped map label, aff, h = utils.load_volume(path_label, im_only=False) label, cropping, aff = crop_volume_around_region(label, aff=aff) - utils.save_volume(label, aff, h, os.path.join(result_dir, os.path.basename(path_label))) - maximum_size = np.maximum(maximum_size, np.array(label.shape) + margin * 2) # *2 to add margin on each side + utils.save_volume( + label, aff, h, os.path.join(result_dir, os.path.basename(path_label)) + ) + maximum_size = np.maximum( + maximum_size, np.array(label.shape) + margin * 2 + ) # *2 to add margin on each side # crop images if required if path_image is not None: image, aff_im, h_im = utils.load_volume(path_image, im_only=False) image, aff_im = crop_volume_with_idx(image, cropping, aff=aff_im) - utils.save_volume(image, aff_im, h_im, os.path.join(image_result_dir, os.path.basename(path_image))) + utils.save_volume( + image, + aff_im, + h_im, + os.path.join(image_result_dir, os.path.basename(path_image)), + ) # loop over label maps for padding - print('\npadding labels to same size') - loop_info = utils.LoopInfo(len(path_labels), 10, 'padding', True) + print("\npadding labels to same size") + loop_info = utils.LoopInfo(len(path_labels), 10, "padding", True) for idx, (path_label, path_image) in enumerate(zip(path_labels, path_images)): loop_info.update(idx) @@ -2519,29 +3095,47 @@ def crop_dataset_to_minimum_size(labels_dir, result_dir, image_dir=None, image_r utils.save_volume(image, aff, h, path_result) -def crop_dataset_around_region_of_same_size(labels_dir, - result_dir, - image_dir=None, - image_result_dir=None, - margin=0, - recompute=True): +def crop_dataset_around_region_of_same_size( + labels_dir, + result_dir, + image_dir=None, + image_result_dir=None, + margin=0, + recompute=True, +): # create result dir utils.mkdir(result_dir) if image_dir is not None: - assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" utils.mkdir(image_result_dir) # list labels and images path_labels = utils.list_images_in_folder(labels_dir) - path_images = utils.list_images_in_folder(image_dir) if image_dir is not None else [None] * len(path_labels) + path_images = ( + utils.list_images_in_folder(image_dir) + if image_dir is not None + else [None] * len(path_labels) + ) _, _, n_dims, _, _, _ = utils.get_volume_info(path_labels[0]) - recompute_labels = any([not os.path.isfile(os.path.join(result_dir, os.path.basename(path))) - for path in path_labels]) + recompute_labels = any( + [ + not os.path.isfile(os.path.join(result_dir, os.path.basename(path))) + for path in path_labels + ] + ) if (image_dir is not None) & (not recompute_labels): - recompute_labels = any([not os.path.isfile(os.path.join(image_result_dir, os.path.basename(path))) - for path in path_images]) + recompute_labels = any( + [ + not os.path.isfile( + os.path.join(image_result_dir, os.path.basename(path)) + ) + for path in path_images + ] + ) # get minimum patch shape so that no labels are left out when doing the cropping later on max_crop_shape = np.zeros(n_dims) @@ -2549,11 +3143,17 @@ def crop_dataset_around_region_of_same_size(labels_dir, for path_label in path_labels: label, aff, _ = utils.load_volume(path_label, im_only=False) label = align_volume_to_ref(label, aff, aff_ref=np.eye(4)) - label = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + label = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) _, cropping = crop_volume_around_region(label) - max_crop_shape = np.maximum(cropping[n_dims:] - cropping[:n_dims], max_crop_shape) - max_crop_shape += np.array(utils.reformat_to_list(margin, length=n_dims, dtype='int')) - print('max_crop_shape: ', max_crop_shape) + max_crop_shape = np.maximum( + cropping[n_dims:] - cropping[:n_dims], max_crop_shape + ) + max_crop_shape += np.array( + utils.reformat_to_list(margin, length=n_dims, dtype="int") + ) + print("max_crop_shape: ", max_crop_shape) # crop shapes (possibly with padding if images are smaller than crop shape) for path_label, path_image in zip(path_labels, path_images): @@ -2561,10 +3161,18 @@ def crop_dataset_around_region_of_same_size(labels_dir, path_label_result = os.path.join(result_dir, os.path.basename(path_label)) path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) - if (not os.path.isfile(path_image_result)) | (not os.path.isfile(path_label_result)) | recompute: + if ( + (not os.path.isfile(path_image_result)) + | (not os.path.isfile(path_label_result)) + | recompute + ): # load labels - label, aff, h_la = utils.load_volume(path_label, im_only=False, dtype='int32') - label, aff_new = align_volume_to_ref(label, aff, aff_ref=np.eye(4), return_aff=True) + label, aff, h_la = utils.load_volume( + path_label, im_only=False, dtype="int32" + ) + label, aff_new = align_volume_to_ref( + label, aff, aff_ref=np.eye(4), return_aff=True + ) vol_shape = np.array(label.shape[:n_dims]) if path_image is not None: image, _, h_im = utils.load_volume(path_image, im_only=False) @@ -2573,27 +3181,39 @@ def crop_dataset_around_region_of_same_size(labels_dir, image = h_im = None # mask labels - mask = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + mask = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) label[np.logical_not(mask)] = 0 # find cropping indices indices = np.nonzero(mask) min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) - max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) + max_idx = np.minimum( + np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape + ) # expand/retract (depending on the desired shape) the cropping region around the centre intermediate_vol_shape = max_idx - min_idx - min_idx = min_idx - np.int32(np.ceil((max_crop_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + np.int32(np.floor((max_crop_shape - intermediate_vol_shape) / 2)) + min_idx = min_idx - np.int32( + np.ceil((max_crop_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((max_crop_shape - intermediate_vol_shape) / 2) + ) # check if we need to pad the output to the desired shape min_padding = np.abs(np.minimum(min_idx, 0)) max_padding = np.maximum(max_idx - vol_shape, 0) if np.any(min_padding > 0) | np.any(max_padding > 0): - pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) else: pad_margins = None - cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) # crop volume label = crop_volume_with_idx(label, cropping, n_dims=n_dims) @@ -2602,11 +3222,17 @@ def crop_dataset_around_region_of_same_size(labels_dir, # pad volume if necessary if pad_margins is not None: - label = np.pad(label, pad_margins, mode='constant', constant_values=0) + label = np.pad(label, pad_margins, mode="constant", constant_values=0) if path_image is not None: _, n_channels = utils.get_dims(image.shape) - pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins - image = np.pad(image, pad_margins, mode='constant', constant_values=0) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) + if n_channels > 1 + else pad_margins + ) + image = np.pad( + image, pad_margins, mode="constant", constant_values=0 + ) # update aff if n_dims == 2: @@ -2614,15 +3240,24 @@ def crop_dataset_around_region_of_same_size(labels_dir, aff_new[0:3, -1] = aff_new[0:3, -1] + aff_new[:3, :3] @ min_idx # write labels - label, aff_final = align_volume_to_ref(label, aff_new, aff_ref=aff, return_aff=True) - utils.save_volume(label, aff_final, h_la, path_label_result, dtype='int32') + label, aff_final = align_volume_to_ref( + label, aff_new, aff_ref=aff, return_aff=True + ) + utils.save_volume(label, aff_final, h_la, path_label_result, dtype="int32") if path_image is not None: image = align_volume_to_ref(image, aff_new, aff_ref=aff) utils.save_volume(image, aff_final, h_im, path_image_result) -def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_result_dir, margin=0, - cropping_shape_div_by=None, recompute=True): +def crop_dataset_around_region( + image_dir, + labels_dir, + image_result_dir, + labels_result_dir, + margin=0, + cropping_shape_div_by=None, + recompute=True, +): # create result dir utils.mkdir(image_result_dir) @@ -2634,42 +3269,65 @@ def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_r _, _, n_dims, n_channels, _, _ = utils.get_volume_info(path_labels[0]) # loop over images and labels - loop_info = utils.LoopInfo(len(path_images), 10, 'cropping', True) + loop_info = utils.LoopInfo(len(path_images), 10, "cropping", True) for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): loop_info.update(idx) - path_label_result = os.path.join(labels_result_dir, os.path.basename(path_label)) + path_label_result = os.path.join( + labels_result_dir, os.path.basename(path_label) + ) path_image_result = os.path.join(image_result_dir, os.path.basename(path_image)) - if (not os.path.isfile(path_label_result)) | (not os.path.isfile(path_image_result)) | recompute: + if ( + (not os.path.isfile(path_label_result)) + | (not os.path.isfile(path_image_result)) + | recompute + ): image, aff, h_im = utils.load_volume(path_image, im_only=False) label, _, h_lab = utils.load_volume(path_label, im_only=False) - mask = get_largest_connected_component(label > 0, structure=np.ones((3, 3, 3))) + mask = get_largest_connected_component( + label > 0, structure=np.ones((3, 3, 3)) + ) label[np.logical_not(mask)] = 0 vol_shape = np.array(label.shape[:n_dims]) # find cropping indices indices = np.nonzero(mask) min_idx = np.maximum(np.array([np.min(idx) for idx in indices]) - margin, 0) - max_idx = np.minimum(np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape) + max_idx = np.minimum( + np.array([np.max(idx) for idx in indices]) + 1 + margin, vol_shape + ) # expand/retract (depending on the desired shape) the cropping region around the centre intermediate_vol_shape = max_idx - min_idx - cropping_shape = np.array([utils.find_closest_number_divisible_by_m(s, cropping_shape_div_by, - answer_type='higher') - for s in intermediate_vol_shape]) - min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2)) - max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2)) + cropping_shape = np.array( + [ + utils.find_closest_number_divisible_by_m( + s, cropping_shape_div_by, answer_type="higher" + ) + for s in intermediate_vol_shape + ] + ) + min_idx = min_idx - np.int32( + np.ceil((cropping_shape - intermediate_vol_shape) / 2) + ) + max_idx = max_idx + np.int32( + np.floor((cropping_shape - intermediate_vol_shape) / 2) + ) # check if we need to pad the output to the desired shape min_padding = np.abs(np.minimum(min_idx, 0)) max_padding = np.maximum(max_idx - vol_shape, 0) if np.any(min_padding > 0) | np.any(max_padding > 0): - pad_margins = tuple([(min_padding[i], max_padding[i]) for i in range(n_dims)]) + pad_margins = tuple( + [(min_padding[i], max_padding[i]) for i in range(n_dims)] + ) else: pad_margins = None - cropping = np.concatenate([np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)]) + cropping = np.concatenate( + [np.maximum(min_idx, 0), np.minimum(max_idx, vol_shape)] + ) # crop volume label = crop_volume_with_idx(label, cropping, n_dims=n_dims) @@ -2677,9 +3335,13 @@ def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_r # pad volume if necessary if pad_margins is not None: - label = np.pad(label, pad_margins, mode='constant', constant_values=0) - pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins - image = np.pad(image, pad_margins, mode='constant', constant_values=0) + label = np.pad(label, pad_margins, mode="constant", constant_values=0) + pad_margins = ( + tuple(list(pad_margins) + [(0, 0)]) + if n_channels > 1 + else pad_margins + ) + image = np.pad(image, pad_margins, mode="constant", constant_values=0) # update aff if n_dims == 2: @@ -2688,16 +3350,18 @@ def crop_dataset_around_region(image_dir, labels_dir, image_result_dir, labels_r # write results utils.save_volume(image, aff, h_im, path_image_result) - utils.save_volume(label, aff, h_lab, path_label_result, dtype='int32') - - -def subdivide_dataset_to_patches(patch_shape, - image_dir=None, - image_result_dir=None, - labels_dir=None, - labels_result_dir=None, - full_background=True, - remove_after_dividing=False): + utils.save_volume(label, aff, h_lab, path_label_result, dtype="int32") + + +def subdivide_dataset_to_patches( + patch_shape, + image_dir=None, + image_result_dir=None, + labels_dir=None, + labels_result_dir=None, + full_background=True, + remove_after_dividing=False, +): """This function subdivides images and/or label maps into several smaller patches of specified shape. :param patch_shape: shape of patches to create. Can either be an int, a sequence, or a 1d numpy array. :param image_dir: (optional) path of directory with input images @@ -2711,16 +3375,21 @@ def subdivide_dataset_to_patches(patch_shape, """ # create result dir and list images and label maps - assert (image_dir is not None) | (labels_dir is not None), \ - 'at least one of image_dir or labels_dir should not be None.' + assert (image_dir is not None) | ( + labels_dir is not None + ), "at least one of image_dir or labels_dir should not be None." if image_dir is not None: - assert image_result_dir is not None, 'image_result_dir should not be None if image_dir is specified' + assert ( + image_result_dir is not None + ), "image_result_dir should not be None if image_dir is specified" utils.mkdir(image_result_dir) path_images = utils.list_images_in_folder(image_dir) else: path_images = None if labels_dir is not None: - assert labels_result_dir is not None, 'labels_result_dir should not be None if labels_dir is specified' + assert ( + labels_result_dir is not None + ), "labels_result_dir should not be None if labels_dir is specified" utils.mkdir(labels_result_dir) path_labels = utils.list_images_in_folder(labels_dir) else: @@ -2735,17 +3404,21 @@ def subdivide_dataset_to_patches(patch_shape, n_dims, _ = utils.get_dims(patch_shape) # loop over images and labels - loop_info = utils.LoopInfo(len(path_images), 10, 'processing', True) + loop_info = utils.LoopInfo(len(path_images), 10, "processing", True) for idx, (path_image, path_label) in enumerate(zip(path_images, path_labels)): loop_info.update(idx) # load image and labels if path_image is not None: - im, aff_im, h_im = utils.load_volume(path_image, im_only=False, squeeze=False) + im, aff_im, h_im = utils.load_volume( + path_image, im_only=False, squeeze=False + ) else: im = aff_im = h_im = None if path_label is not None: - lab, aff_lab, h_lab = utils.load_volume(path_label, im_only=False, squeeze=True) + lab, aff_lab, h_lab = utils.load_volume( + path_label, im_only=False, squeeze=True + ) else: lab = aff_lab = h_lab = None @@ -2756,21 +3429,26 @@ def subdivide_dataset_to_patches(patch_shape, shape = lab.shape # crop image and label map to size divisible by patch_shape - new_size = np.array([utils.find_closest_number_divisible_by_m(shape[i], patch_shape[i]) for i in range(n_dims)]) - crop = np.round((np.array(shape[:n_dims]) - new_size) / 2).astype('int') + new_size = np.array( + [ + utils.find_closest_number_divisible_by_m(shape[i], patch_shape[i]) + for i in range(n_dims) + ] + ) + crop = np.round((np.array(shape[:n_dims]) - new_size) / 2).astype("int") crop = np.concatenate((crop, crop + new_size), axis=0) if (im is not None) & (n_dims == 2): - im = im[crop[0]:crop[2], crop[1]:crop[3], ...] + im = im[crop[0] : crop[2], crop[1] : crop[3], ...] elif (im is not None) & (n_dims == 3): - im = im[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] + im = im[crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ...] if (lab is not None) & (n_dims == 2): - lab = lab[crop[0]:crop[2], crop[1]:crop[3], ...] + lab = lab[crop[0] : crop[2], crop[1] : crop[3], ...] elif (lab is not None) & (n_dims == 3): - lab = lab[crop[0]:crop[3], crop[1]:crop[4], crop[2]:crop[5], ...] + lab = lab[crop[0] : crop[3], crop[1] : crop[4], crop[2] : crop[5], ...] # loop over patches n_im = 0 - n_crop = (new_size / patch_shape).astype('int') + n_crop = (new_size / patch_shape).astype("int") for i in range(n_crop[0]): i *= patch_shape[0] for j in range(n_crop[1]): @@ -2780,11 +3458,15 @@ def subdivide_dataset_to_patches(patch_shape, # crop volumes if lab is not None: - temp_la = lab[i:i + patch_shape[0], j:j + patch_shape[1], ...] + temp_la = lab[ + i : i + patch_shape[0], j : j + patch_shape[1], ... + ] else: temp_la = None if im is not None: - temp_im = im[i:i + patch_shape[0], j:j + patch_shape[1], ...] + temp_im = im[ + i : i + patch_shape[0], j : j + patch_shape[1], ... + ] else: temp_im = None @@ -2792,14 +3474,45 @@ def subdivide_dataset_to_patches(patch_shape, if temp_la is not None: if full_background | (not (temp_la == 0).all()): n_im += 1 - utils.save_volume(temp_la, aff_lab, h_lab, os.path.join(labels_result_dir, - os.path.basename(path_label.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_la, + aff_lab, + h_lab, + os.path.join( + labels_result_dir, + os.path.basename( + path_label.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) if temp_im is not None: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) else: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace(".nii.gz", "_%d.nii.gz" % n_im) + ), + ), + ) elif n_dims == 3: for k in range(n_crop[2]): @@ -2807,11 +3520,21 @@ def subdivide_dataset_to_patches(patch_shape, # crop volumes if lab is not None: - temp_la = lab[i:i + patch_shape[0], j:j + patch_shape[1], k:k + patch_shape[2], ...] + temp_la = lab[ + i : i + patch_shape[0], + j : j + patch_shape[1], + k : k + patch_shape[2], + ..., + ] else: temp_la = None if im is not None: - temp_im = im[i:i + patch_shape[0], j:j + patch_shape[1], k:k + patch_shape[2], ...] + temp_im = im[ + i : i + patch_shape[0], + j : j + patch_shape[1], + k : k + patch_shape[2], + ..., + ] else: temp_im = None @@ -2819,15 +3542,47 @@ def subdivide_dataset_to_patches(patch_shape, if temp_la is not None: if full_background | (not (temp_la == 0).all()): n_im += 1 - utils.save_volume(temp_la, aff_lab, h_lab, os.path.join(labels_result_dir, - os.path.basename(path_label.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_la, + aff_lab, + h_lab, + os.path.join( + labels_result_dir, + os.path.basename( + path_label.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) if temp_im is not None: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', - '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) else: - utils.save_volume(temp_im, aff_im, h_im, os.path.join(image_result_dir, - os.path.basename(path_image.replace('.nii.gz', '_%d.nii.gz' % n_im)))) + utils.save_volume( + temp_im, + aff_im, + h_im, + os.path.join( + image_result_dir, + os.path.basename( + path_image.replace( + ".nii.gz", "_%d.nii.gz" % n_im + ) + ), + ), + ) if remove_after_dividing: if path_image is not None: diff --git a/nobrainer/ext/lab2im/image_generator.py b/nobrainer/ext/lab2im/image_generator.py index 073442e8..d8f83bc0 100644 --- a/nobrainer/ext/lab2im/image_generator.py +++ b/nobrainer/ext/lab2im/image_generator.py @@ -13,34 +13,34 @@ License. """ - # python imports import numpy as np import numpy.random as npr # project imports -from nobrainer.ext.lab2im import utils -from nobrainer.ext.lab2im import edit_volumes +from nobrainer.ext.lab2im import edit_volumes, utils from nobrainer.ext.lab2im.lab2im_model import lab2im_model class ImageGenerator: - def __init__(self, - labels_dir, - generation_labels=None, - output_labels=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - output_div_by_n=None, - generation_classes=None, - prior_distributions='uniform', - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - blur_range=1.15): + def __init__( + self, + labels_dir, + generation_labels=None, + output_labels=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + generation_classes=None, + prior_distributions="uniform", + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + blur_range=1.15, + ): """ This class is wrapper around the lab2im_model model. It contains the GPU model that generates images from labels maps, and a python generator that supplies the input data for this model. @@ -115,8 +115,9 @@ def __init__(self, self.labels_paths = utils.list_images_in_folder(labels_dir) # generation parameters - self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = ( utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + ) self.n_channels = n_channels if generation_labels is not None: self.generation_labels = utils.load_array_if_path(generation_labels) @@ -135,11 +136,13 @@ def __init__(self, self.prior_distributions = prior_distributions if generation_classes is not None: self.generation_classes = utils.load_array_if_path(generation_classes) - assert self.generation_classes.shape == self.generation_labels.shape, \ - 'if provided, generation labels should have the same shape as generation_labels' + assert ( + self.generation_classes.shape == self.generation_labels.shape + ), "if provided, generation labels should have the same shape as generation_labels" unique_classes = np.unique(self.generation_classes) - assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ - 'generation_classes should a linear range between 0 and its maximum value.' + assert np.array_equal( + unique_classes, np.arange(np.max(unique_classes) + 1) + ), "generation_classes should a linear range between 0 and its maximum value." else: self.generation_classes = np.arange(self.generation_labels.shape[0]) self.prior_means = utils.load_array_if_path(prior_means) @@ -153,22 +156,26 @@ def __init__(self, self.labels_to_image_model, self.model_output_shape = self._build_lab2im_model() # build generator for model inputs - self.model_inputs_generator = self._build_model_inputs(len(self.generation_labels)) + self.model_inputs_generator = self._build_model_inputs( + len(self.generation_labels) + ) # build brain generator self.image_generator = self._build_image_generator() def _build_lab2im_model(self): # build_model - lab_to_im_model = lab2im_model(labels_shape=self.labels_shape, - n_channels=self.n_channels, - generation_labels=self.generation_labels, - output_labels=self.output_labels, - atlas_res=self.atlas_res, - target_res=self.target_res, - output_shape=self.output_shape, - output_div_by_n=self.output_div_by_n, - blur_range=self.blur_range) + lab_to_im_model = lab2im_model( + labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + blur_range=self.blur_range, + ) out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] return lab_to_im_model, out_shape @@ -185,10 +192,16 @@ def generate_image(self): list_images = list() list_labels = list() for i in range(self.batchsize): - list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), aff_ref=self.aff, - n_dims=self.n_dims)) - list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), aff_ref=self.aff, - n_dims=self.n_dims)) + list_images.append( + edit_volumes.align_volume_to_ref( + image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + list_labels.append( + edit_volumes.align_volume_to_ref( + labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) image = np.stack(list_images, axis=0) labels = np.stack(list_labels, axis=0) return np.squeeze(image), np.squeeze(labels) @@ -212,7 +225,9 @@ def _build_model_inputs(self, n_labels): for idx in indices: # load label in identity space, and add them to inputs - y = utils.load_volume(self.labels_paths[idx], dtype='int', aff_ref=np.eye(4)) + y = utils.load_volume( + self.labels_paths[idx], dtype="int", aff_ref=np.eye(4) + ) list_label_maps.append(utils.add_axis(y, axis=[0, -1])) # add means and standard deviations to inputs @@ -222,35 +237,61 @@ def _build_model_inputs(self, n_labels): # retrieve channel specific stats if necessary if isinstance(self.prior_means, np.ndarray): - if (self.prior_means.shape[0] > 2) & self.use_specific_stats_for_channel: + if ( + self.prior_means.shape[0] > 2 + ) & self.use_specific_stats_for_channel: if self.prior_means.shape[0] / 2 != self.n_channels: - raise ValueError("the number of blocks in prior_means does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_means = self.prior_means[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_means does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_means = self.prior_means[ + 2 * channel : 2 * channel + 2, : + ] else: tmp_prior_means = self.prior_means else: tmp_prior_means = self.prior_means if isinstance(self.prior_stds, np.ndarray): - if (self.prior_stds.shape[0] > 2) & self.use_specific_stats_for_channel: + if ( + self.prior_stds.shape[0] > 2 + ) & self.use_specific_stats_for_channel: if self.prior_stds.shape[0] / 2 != self.n_channels: - raise ValueError("the number of blocks in prior_stds does not match n_channels. This " - "message is printed because use_specific_stats_for_channel is True.") - tmp_prior_stds = self.prior_stds[2 * channel:2 * channel + 2, :] + raise ValueError( + "the number of blocks in prior_stds does not match n_channels. This " + "message is printed because use_specific_stats_for_channel is True." + ) + tmp_prior_stds = self.prior_stds[ + 2 * channel : 2 * channel + 2, : + ] else: tmp_prior_stds = self.prior_stds else: tmp_prior_stds = self.prior_stds # draw means and std devs from priors - tmp_classes_means = utils.draw_value_from_distribution(tmp_prior_means, n_labels, - self.prior_distributions, 125., 100., - positive_only=True) - tmp_classes_stds = utils.draw_value_from_distribution(tmp_prior_stds, n_labels, - self.prior_distributions, 15., 10., - positive_only=True) - tmp_means = utils.add_axis(tmp_classes_means[self.generation_classes], axis=[0, -1]) - tmp_stds = utils.add_axis(tmp_classes_stds[self.generation_classes], axis=[0, -1]) + tmp_classes_means = utils.draw_value_from_distribution( + tmp_prior_means, + n_labels, + self.prior_distributions, + 125.0, + 100.0, + positive_only=True, + ) + tmp_classes_stds = utils.draw_value_from_distribution( + tmp_prior_stds, + n_labels, + self.prior_distributions, + 15.0, + 10.0, + positive_only=True, + ) + tmp_means = utils.add_axis( + tmp_classes_means[self.generation_classes], axis=[0, -1] + ) + tmp_stds = utils.add_axis( + tmp_classes_stds[self.generation_classes], axis=[0, -1] + ) means = np.concatenate([means, tmp_means], axis=-1) stds = np.concatenate([stds, tmp_stds], axis=-1) list_means.append(means) @@ -258,7 +299,9 @@ def _build_model_inputs(self, n_labels): # build list of inputs of augmentation model list_inputs = [list_label_maps, list_means, list_stds] - if self.batchsize > 1: # concatenate individual input types if batchsize > 1 + if ( + self.batchsize > 1 + ): # concatenate individual input types if batchsize > 1 list_inputs = [np.concatenate(item, 0) for item in list_inputs] else: list_inputs = [item[0] for item in list_inputs] diff --git a/nobrainer/ext/lab2im/lab2im_model.py b/nobrainer/ext/lab2im/lab2im_model.py index f32c5b5a..96b0ee8b 100644 --- a/nobrainer/ext/lab2im/lab2im_model.py +++ b/nobrainer/ext/lab2im/lab2im_model.py @@ -13,27 +13,31 @@ License. """ +import keras.layers as KL +from keras.models import Model # python imports import numpy as np -import keras.layers as KL -from keras.models import Model # project imports -from nobrainer.ext.lab2im import utils -from nobrainer.ext.lab2im import layers -from nobrainer.ext.lab2im.edit_tensors import resample_tensor, blurring_sigma_for_downsampling - - -def lab2im_model(labels_shape, - n_channels, - generation_labels, - output_labels, - atlas_res, - target_res, - output_shape=None, - output_div_by_n=None, - blur_range=1.15): +from nobrainer.ext.lab2im import layers, utils +from nobrainer.ext.lab2im.edit_tensors import ( + blurring_sigma_for_downsampling, + resample_tensor, +) + + +def lab2im_model( + labels_shape, + n_channels, + generation_labels, + output_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + blur_range=1.15, +): """ This function builds a keras/tensorflow model to generate images from provided label maps. The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. @@ -74,18 +78,30 @@ def lab2im_model(labels_shape, labels_shape = utils.reformat_to_list(labels_shape) n_dims, _ = utils.get_dims(labels_shape) atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims=n_dims)[0] - target_res = atlas_res if (target_res is None) else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + target_res = ( + atlas_res + if (target_res is None) + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) # get shapes - crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) # define model inputs - labels_input = KL.Input(shape=labels_shape+[1], name='labels_input', dtype='int32') - means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') - stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='stds_input') + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="stds_input" + ) # deform labels - labels = layers.RandomSpatialDeformation(inter_method='nearest')(labels_input) + labels = layers.RandomSpatialDeformation(inter_method="nearest")(labels_input) # cropping if crop_shape != labels_shape: @@ -94,15 +110,19 @@ def lab2im_model(labels_shape, # build synthetic image labels._keras_shape = tuple(labels.get_shape().as_list()) - image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) # apply bias field image._keras_shape = tuple(image.get_shape().as_list()) - image = layers.BiasFieldCorruption(.3, .025, same_bias_for_all_channels=False)(image) + image = layers.BiasFieldCorruption(0.3, 0.025, same_bias_for_all_channels=False)( + image + ) # intensity augmentation image._keras_shape = tuple(image.get_shape().as_list()) - image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.2)(image) + image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=0.2)(image) # blur image sigma = blurring_sigma_for_downsampling(atlas_res, target_res) @@ -111,15 +131,19 @@ def lab2im_model(labels_shape, # resample to target res if crop_shape != output_shape: - image = resample_tensor(image, output_shape, interp_method='linear') - labels = resample_tensor(labels, output_shape, interp_method='nearest') + image = resample_tensor(image, output_shape, interp_method="linear") + labels = resample_tensor(labels, output_shape, interp_method="nearest") # reset unwanted labels to zero - labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) # build model (dummy layer enables to keep the labels when plugging this model to other models) - image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) - brain_model = Model(inputs=[labels_input, means_input, stds_input], outputs=[image, labels]) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) + brain_model = Model( + inputs=[labels_input, means_input, stds_input], outputs=[image, labels] + ) return brain_model @@ -136,26 +160,39 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ # output shape specified, need to get cropping shape, and resample shape if necessary if output_shape is not None: - output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int") # make sure that output shape is smaller or equal to label shape if resample_factor is not None: - output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) + for i in range(n_dims) + ] else: - output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(labels_shape[i], output_shape[i]) for i in range(n_dims) + ] # make sure output shape is divisible by output_div_by_n if output_div_by_n is not None: - tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) - for s in output_shape] + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] if output_shape != tmp_shape: - print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, - tmp_shape)) + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + output_shape, output_div_by_n, tmp_shape + ) + ) output_shape = tmp_shape # get cropping and resample shape if resample_factor is not None: - cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] else: cropping_shape = output_shape @@ -163,12 +200,19 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ else: cropping_shape = labels_shape if resample_factor is not None: - output_shape = [int(np.around(cropping_shape[i]*resample_factor[i], 0)) for i in range(n_dims)] + output_shape = [ + int(np.around(cropping_shape[i] * resample_factor[i], 0)) + for i in range(n_dims) + ] else: output_shape = cropping_shape # make sure output shape is divisible by output_div_by_n if output_div_by_n is not None: - output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n, answer_type='closer') - for s in output_shape] + output_shape = [ + utils.find_closest_number_divisible_by_m( + s, output_div_by_n, answer_type="closer" + ) + for s in output_shape + ] return cropping_shape, output_shape diff --git a/nobrainer/ext/lab2im/layers.py b/nobrainer/ext/lab2im/layers.py index e477d607..c171e428 100644 --- a/nobrainer/ext/lab2im/layers.py +++ b/nobrainer/ext/lab2im/layers.py @@ -34,17 +34,16 @@ License. """ - # python imports import keras -import numpy as np -import tensorflow as tf import keras.backend as K from keras.layers import Layer +import numpy as np +import tensorflow as tf # project imports -from nobrainer.ext.lab2im import utils from nobrainer.ext.lab2im import edit_tensors as l2i_et +from nobrainer.ext.lab2im import utils # third-party imports from nobrainer.ext.neuron import utils as nrn_utils @@ -85,17 +84,19 @@ class RandomSpatialDeformation(Layer): :param prob_deform: (optional) probability to apply spatial deformation """ - def __init__(self, - scaling_bounds=0.15, - rotation_bounds=10, - shearing_bounds=0.02, - translation_bounds=False, - enable_90_rotations=False, - nonlin_std=4., - nonlin_scale=.0625, - inter_method='linear', - prob_deform=1, - **kwargs): + def __init__( + self, + scaling_bounds=0.15, + rotation_bounds=10, + shearing_bounds=0.02, + translation_bounds=False, + enable_90_rotations=False, + nonlin_std=4.0, + nonlin_scale=0.0625, + inter_method="linear", + prob_deform=1, + **kwargs + ): # shape attributes self.n_inputs = 1 @@ -113,9 +114,13 @@ def __init__(self, self.nonlin_scale = nonlin_scale # boolean attributes - self.apply_affine_trans = (self.scaling_bounds is not False) | (self.rotation_bounds is not False) | \ - (self.shearing_bounds is not False) | (self.translation_bounds is not False) | \ - self.enable_90_rotations + self.apply_affine_trans = ( + (self.scaling_bounds is not False) + | (self.rotation_bounds is not False) + | (self.shearing_bounds is not False) + | (self.translation_bounds is not False) + | self.enable_90_rotations + ) self.apply_elastic_trans = self.nonlin_std > 0 self.prob_deform = prob_deform @@ -148,12 +153,15 @@ def build(self, input_shape): self.n_dims = len(self.inshape) - 1 if self.apply_elastic_trans: - self.small_shape = utils.get_resample_shape(self.inshape[:self.n_dims], - self.nonlin_scale, self.n_dims) + self.small_shape = utils.get_resample_shape( + self.inshape[: self.n_dims], self.nonlin_scale, self.n_dims + ) else: self.small_shape = None - self.inter_method = utils.reformat_to_list(self.inter_method, length=self.n_inputs, dtype='str') + self.inter_method = utils.reformat_to_list( + self.inter_method, length=self.n_inputs, dtype="str" + ) self.built = True super(RandomSpatialDeformation, self).build(input_shape) @@ -164,7 +172,7 @@ def call(self, inputs, **kwargs): if self.n_inputs < 2: inputs = [inputs] types = [v.dtype for v in inputs] - inputs = [tf.cast(v, dtype='float32') for v in inputs] + inputs = [tf.cast(v, dtype="float32") for v in inputs] batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] # initialise list of transforms to operate @@ -172,39 +180,61 @@ def call(self, inputs, **kwargs): # add affine deformation to inputs list if self.apply_affine_trans: - affine_trans = utils.sample_affine_transform(batchsize, - self.n_dims, - self.rotation_bounds, - self.scaling_bounds, - self.shearing_bounds, - self.translation_bounds, - self.enable_90_rotations) + affine_trans = utils.sample_affine_transform( + batchsize, + self.n_dims, + self.rotation_bounds, + self.scaling_bounds, + self.shearing_bounds, + self.translation_bounds, + self.enable_90_rotations, + ) list_trans.append(affine_trans) # prepare non-linear deformation field and add it to inputs list if self.apply_elastic_trans: # sample small field from normal distribution of specified std dev - trans_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_shape, dtype='int32')], axis=0) + trans_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.small_shape, dtype="int32")], + axis=0, + ) trans_std = tf.random.uniform((1, 1), maxval=self.nonlin_std) elastic_trans = tf.random.normal(trans_shape, stddev=trans_std) # reshape this field to half size (for smoother SVF), integrate it, and reshape to full image size - resize_shape = [max(int(self.inshape[i] / 2), self.small_shape[i]) for i in range(self.n_dims)] - elastic_trans = nrn_layers.Resize(size=resize_shape, interp_method='linear')(elastic_trans) + resize_shape = [ + max(int(self.inshape[i] / 2), self.small_shape[i]) + for i in range(self.n_dims) + ] + elastic_trans = nrn_layers.Resize( + size=resize_shape, interp_method="linear" + )(elastic_trans) elastic_trans = nrn_layers.VecInt()(elastic_trans) - elastic_trans = nrn_layers.Resize(size=self.inshape[:self.n_dims], interp_method='linear')(elastic_trans) + elastic_trans = nrn_layers.Resize( + size=self.inshape[: self.n_dims], interp_method="linear" + )(elastic_trans) list_trans.append(elastic_trans) # apply deformations and return tensors with correct dtype if self.apply_affine_trans | self.apply_elastic_trans: if self.prob_deform == 1: - inputs = [nrn_layers.SpatialTransformer(m)([v] + list_trans) for (m, v) in - zip(self.inter_method, inputs)] + inputs = [ + nrn_layers.SpatialTransformer(m)([v] + list_trans) + for (m, v) in zip(self.inter_method, inputs) + ] else: - rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_deform)) - inputs = [K.switch(rand_trans, nrn_layers.SpatialTransformer(m)([v] + list_trans), v) - for (m, v) in zip(self.inter_method, inputs)] + rand_trans = tf.squeeze( + K.less(tf.random.uniform([1], 0, 1), self.prob_deform) + ) + inputs = [ + K.switch( + rand_trans, + nrn_layers.SpatialTransformer(m)([v] + list_trans), + v, + ) + for (m, v) in zip(self.inter_method, inputs) + ] if self.n_inputs < 2: return tf.cast(inputs[0], types[0]) else: @@ -244,7 +274,9 @@ def build(self, input_shape): inputshape = [input_shape] else: inputshape = input_shape - self.crop_max_val = np.array(np.array(inputshape[0][1:self.n_dims + 1])) - np.array(self.crop_shape) + self.crop_max_val = np.array( + np.array(inputshape[0][1 : self.n_dims + 1]) + ) - np.array(self.crop_shape) self.list_n_channels = [i[-1] for i in inputshape] self.built = True super(RandomCrop, self).build(input_shape) @@ -258,19 +290,24 @@ def call(self, inputs, **kwargs): # otherwise we concatenate all inputs before cropping, so that they are all cropped at the same location else: types = [v.dtype for v in inputs] - inputs = tf.concat([tf.cast(v, 'float32') for v in inputs], axis=-1) + inputs = tf.concat([tf.cast(v, "float32") for v in inputs], axis=-1) inputs = tf.map_fn(self._single_slice, inputs, dtype=tf.float32) inputs = tf.split(inputs, self.list_n_channels, axis=-1) return [tf.cast(v, t) for (t, v) in zip(types, inputs)] def _single_slice(self, vol): - crop_idx = tf.cast(tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), 'float32'), dtype='int32') - crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype='int32')], axis=0) - crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype='int32') + crop_idx = tf.cast( + tf.random.uniform([self.n_dims], 0, np.array(self.crop_max_val), "float32"), + dtype="int32", + ) + crop_idx = tf.concat([crop_idx, tf.zeros([1], dtype="int32")], axis=0) + crop_size = tf.convert_to_tensor(self.crop_shape + [-1], dtype="int32") return tf.slice(vol, begin=crop_idx, size=crop_size) def compute_output_shape(self, input_shape): - output_shape = [tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels] + output_shape = [ + tuple([None] + self.crop_shape + [v]) for v in self.list_n_channels + ] return output_shape if self.several_inputs else output_shape[0] @@ -329,7 +366,15 @@ class RandomFlip(Layer): This doesn't concern the image input, as its values are not swapped. """ - def __init__(self, axis=None, swap_labels=False, label_list=None, n_neutral_labels=None, prob=0.5, **kwargs): + def __init__( + self, + axis=None, + swap_labels=False, + label_list=None, + n_neutral_labels=None, + prob=0.5, + **kwargs + ): # shape attributes self.several_inputs = True @@ -368,22 +413,35 @@ def build(self, input_shape): inputshape = input_shape self.n_dims = len(inputshape[0][1:-1]) self.list_n_channels = [i[-1] for i in inputshape] - self.swap_labels = utils.reformat_to_list(self.swap_labels, length=len(inputshape)) - self.flip_axes = np.arange(self.n_dims).tolist() if self.axis is None else self.axis + self.swap_labels = utils.reformat_to_list( + self.swap_labels, length=len(inputshape) + ) + self.flip_axes = ( + np.arange(self.n_dims).tolist() if self.axis is None else self.axis + ) # create label list with swapped labels if any(self.swap_labels): - assert (self.label_list is not None) & (self.n_neutral_labels is not None), \ - 'please provide a label_list, and n_neutral_labels when swapping the values of at least one input' + assert (self.label_list is not None) & ( + self.n_neutral_labels is not None + ), "please provide a label_list, and n_neutral_labels when swapping the values of at least one input" n_labels = len(self.label_list) if self.n_neutral_labels == n_labels: self.swap_labels = [False] * len(self.swap_labels) else: - rl_split = np.split(self.label_list, [self.n_neutral_labels, - self.n_neutral_labels + int((n_labels-self.n_neutral_labels)/2)]) - label_list_swap = np.concatenate((rl_split[0], rl_split[2], rl_split[1])) + rl_split = np.split( + self.label_list, + [ + self.n_neutral_labels, + self.n_neutral_labels + + int((n_labels - self.n_neutral_labels) / 2), + ], + ) + label_list_swap = np.concatenate( + (rl_split[0], rl_split[2], rl_split[1]) + ) swap_lut = utils.get_mapping_lut(self.label_list, label_list_swap) - self.swap_lut = tf.convert_to_tensor(swap_lut, dtype='int32') + self.swap_lut = tf.convert_to_tensor(swap_lut, dtype="int32") self.built = True super(RandomFlip, self).build(input_shape) @@ -396,20 +454,29 @@ def call(self, inputs, **kwargs): # store whether to flip along each specified dimension batchsize = tf.split(tf.shape(inputs[0]), [1, self.n_dims + 1])[0] - size = tf.concat([batchsize, len(self.flip_axes) * tf.ones(1, dtype='int32')], axis=0) + size = tf.concat( + [batchsize, len(self.flip_axes) * tf.ones(1, dtype="int32")], axis=0 + ) rand_flip = K.less(tf.random.uniform(size, 0, 1), self.prob) # swap right/left labels if we apply an odd number of flips - odd = tf.math.floormod(tf.reduce_sum(tf.cast(rand_flip, 'int32'), -1, keepdims=True), 2) != 0 + odd = ( + tf.math.floormod( + tf.reduce_sum(tf.cast(rand_flip, "int32"), -1, keepdims=True), 2 + ) + != 0 + ) swapped_inputs = list() for i in range(len(inputs)): if self.swap_labels[i]: - swapped_inputs.append(tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i])) + swapped_inputs.append( + tf.map_fn(self._single_swap, [inputs[i], odd], dtype=types[i]) + ) else: swapped_inputs.append(inputs[i]) # flip inputs and convert them back to their original type - inputs = tf.concat([tf.cast(v, 'float32') for v in swapped_inputs], axis=-1) + inputs = tf.concat([tf.cast(v, "float32") for v in swapped_inputs], axis=-1) inputs = tf.map_fn(self._single_flip, [inputs, rand_flip], dtype=tf.float32) inputs = tf.split(inputs, self.list_n_channels, axis=-1) @@ -424,7 +491,11 @@ def _single_swap(self, inputs): @staticmethod def _single_flip(inputs): flip_axis = tf.where(inputs[1]) - return K.switch(tf.equal(tf.size(flip_axis), 0), inputs[0], tf.reverse(inputs[0], axis=flip_axis[..., 0])) + return K.switch( + tf.equal(tf.size(flip_axis), 0), + inputs[0], + tf.reverse(inputs[0], axis=flip_axis[..., 0]), + ) class SampleConditionalGMM(Layer): @@ -462,17 +533,31 @@ def get_config(self): def build(self, input_shape): # check n_labels and n_channels - assert len(input_shape) == 3, 'should have three inputs: labels, means, std devs (in that order).' + assert ( + len(input_shape) == 3 + ), "should have three inputs: labels, means, std devs (in that order)." self.n_channels = input_shape[1][-1] self.n_labels = len(self.generation_labels) - assert self.n_labels == input_shape[1][1], 'means should have the same number of values as generation_labels' - assert self.n_labels == input_shape[2][1], 'stds should have the same number of values as generation_labels' + assert ( + self.n_labels == input_shape[1][1] + ), "means should have the same number of values as generation_labels" + assert ( + self.n_labels == input_shape[2][1] + ), "stds should have the same number of values as generation_labels" # scatter parameters (to build mean/std lut) self.max_label = np.max(self.generation_labels) + 1 - indices = np.concatenate([self.generation_labels + self.max_label * i for i in range(self.n_channels)], axis=-1) - self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype='int32') - self.indices = tf.convert_to_tensor(utils.add_axis(indices, axis=[0, -1]), dtype='int32') + indices = np.concatenate( + [ + self.generation_labels + self.max_label * i + for i in range(self.n_channels) + ], + axis=-1, + ) + self.shape = tf.convert_to_tensor([np.max(indices) + 1], dtype="int32") + self.indices = tf.convert_to_tensor( + utils.add_axis(indices, axis=[0, -1]), dtype="int32" + ) self.built = True super(SampleConditionalGMM, self).build(input_shape) @@ -481,24 +566,56 @@ def call(self, inputs, **kwargs): # reformat labels and scatter indices batch = tf.split(tf.shape(inputs[0]), [1, -1])[0] - tmp_indices = tf.tile(self.indices, tf.concat([batch, tf.convert_to_tensor([1, 1], dtype='int32')], axis=0)) - labels = tf.concat([tf.cast(inputs[0], dtype='int32') + self.max_label * i for i in range(self.n_channels)], -1) + tmp_indices = tf.tile( + self.indices, + tf.concat([batch, tf.convert_to_tensor([1, 1], dtype="int32")], axis=0), + ) + labels = tf.concat( + [ + tf.cast(inputs[0], dtype="int32") + self.max_label * i + for i in range(self.n_channels) + ], + -1, + ) # build mean map means = tf.concat([inputs[1][..., i] for i in range(self.n_channels)], 1) - tile_shape = tf.concat([batch, tf.convert_to_tensor([1, ], dtype='int32')], axis=0) - means = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape) - means_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32) + tile_shape = tf.concat( + [ + batch, + tf.convert_to_tensor( + [ + 1, + ], + dtype="int32", + ), + ], + axis=0, + ) + means = tf.tile( + tf.expand_dims(tf.scatter_nd(tmp_indices, means, self.shape), 0), tile_shape + ) + means_map = tf.map_fn( + lambda x: tf.gather(x[0], x[1]), [means, labels], dtype=tf.float32 + ) # same for stds stds = tf.concat([inputs[2][..., i] for i in range(self.n_channels)], 1) - stds = tf.tile(tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape) - stds_map = tf.map_fn(lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32) + stds = tf.tile( + tf.expand_dims(tf.scatter_nd(tmp_indices, stds, self.shape), 0), tile_shape + ) + stds_map = tf.map_fn( + lambda x: tf.gather(x[0], x[1]), [stds, labels], dtype=tf.float32 + ) return stds_map * tf.random.normal(tf.shape(labels)) + means_map def compute_output_shape(self, input_shape): - return input_shape[0] if (self.n_channels == 1) else tuple(list(input_shape[0][:-1]) + [self.n_channels]) + return ( + input_shape[0] + if (self.n_channels == 1) + else tuple(list(input_shape[0][:-1]) + [self.n_channels]) + ) class SampleResolution(Layer): @@ -528,14 +645,16 @@ class SampleResolution(Layer): """ - def __init__(self, - min_resolution, - max_res_iso=None, - max_res_aniso=None, - prob_iso=0.1, - prob_min=0.05, - return_thickness=True, - **kwargs): + def __init__( + self, + min_resolution, + max_res_iso=None, + max_res_aniso=None, + prob_iso=0.1, + prob_min=0.05, + return_thickness=True, + **kwargs + ): self.min_res = min_resolution self.max_res_iso_input = max_res_iso @@ -563,34 +682,43 @@ def get_config(self): def build(self, input_shape): # check maximum resolutions - assert ((self.max_res_iso_input is not None) | (self.max_res_aniso_input is not None)), \ - 'at least one of maximum isotropic or anisotropic resolutions must be provided, received none' + assert (self.max_res_iso_input is not None) | ( + self.max_res_aniso_input is not None + ), "at least one of maximum isotropic or anisotropic resolutions must be provided, received none" # reformat resolutions as numpy arrays self.min_res = np.array(self.min_res) if self.max_res_iso_input is not None: self.max_res_iso = np.array(self.max_res_iso_input) - assert len(self.min_res) == len(self.max_res_iso), \ - 'min and isotropic max resolution must have the same length, ' \ - 'had {0} and {1}'.format(self.min_res, self.max_res_iso) + assert len(self.min_res) == len(self.max_res_iso), ( + "min and isotropic max resolution must have the same length, " + "had {0} and {1}".format(self.min_res, self.max_res_iso) + ) if np.array_equal(self.min_res, self.max_res_iso): self.max_res_iso = None if self.max_res_aniso_input is not None: self.max_res_aniso = np.array(self.max_res_aniso_input) - assert len(self.min_res) == len(self.max_res_aniso), \ - 'min and anisotropic max resolution must have the same length, ' \ - 'had {} and {}'.format(self.min_res, self.max_res_aniso) + assert len(self.min_res) == len(self.max_res_aniso), ( + "min and anisotropic max resolution must have the same length, " + "had {} and {}".format(self.min_res, self.max_res_aniso) + ) if np.array_equal(self.min_res, self.max_res_aniso): self.max_res_aniso = None # check prob iso - if (self.max_res_iso is not None) & (self.max_res_aniso is not None) & (self.prob_iso == 0): - raise Exception('prob iso is 0 while sampling either isotropic and anisotropic resolutions is enabled') + if ( + (self.max_res_iso is not None) + & (self.max_res_aniso is not None) + & (self.prob_iso == 0) + ): + raise Exception( + "prob iso is 0 while sampling either isotropic and anisotropic resolutions is enabled" + ) if input_shape: self.add_batchsize = True - self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32') + self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype="float32") self.built = True super(SampleResolution, self).build(input_shape) @@ -599,17 +727,36 @@ def call(self, inputs, **kwargs): if not self.add_batchsize: shape = [self.n_dims] - dim = tf.random.uniform(shape=(1, 1), minval=0, maxval=self.n_dims, dtype='int32') - mask = tf.tensor_scatter_nd_update(tf.zeros([self.n_dims], dtype='bool'), dim, - tf.convert_to_tensor([True], dtype='bool')) + dim = tf.random.uniform( + shape=(1, 1), minval=0, maxval=self.n_dims, dtype="int32" + ) + mask = tf.tensor_scatter_nd_update( + tf.zeros([self.n_dims], dtype="bool"), + dim, + tf.convert_to_tensor([True], dtype="bool"), + ) else: batch = tf.split(tf.shape(inputs), [1, -1])[0] - tile_shape = tf.concat([batch, tf.convert_to_tensor([1], dtype='int32')], axis=0) - self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape) - - shape = tf.concat([batch, tf.convert_to_tensor([self.n_dims], dtype='int32')], axis=0) - indices = tf.stack([tf.range(0, batch[0]), tf.random.uniform(batch, 0, self.n_dims, dtype='int32')], 1) - mask = tf.tensor_scatter_nd_update(tf.zeros(shape, dtype='bool'), indices, tf.ones(batch, dtype='bool')) + tile_shape = tf.concat( + [batch, tf.convert_to_tensor([1], dtype="int32")], axis=0 + ) + self.min_res_tens = tf.tile( + tf.expand_dims(self.min_res_tens, 0), tile_shape + ) + + shape = tf.concat( + [batch, tf.convert_to_tensor([self.n_dims], dtype="int32")], axis=0 + ) + indices = tf.stack( + [ + tf.range(0, batch[0]), + tf.random.uniform(batch, 0, self.n_dims, dtype="int32"), + ], + 1, + ) + mask = tf.tensor_scatter_nd_update( + tf.zeros(shape, dtype="bool"), indices, tf.ones(batch, dtype="bool") + ) # return min resolution as tensor if min=max if (self.max_res_iso is None) & (self.max_res_aniso is None): @@ -617,37 +764,60 @@ def call(self, inputs, **kwargs): # sample isotropic resolution only elif (self.max_res_iso is not None) & (self.max_res_aniso is None): - new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), - self.min_res_tens, - new_resolution_iso) + new_resolution_iso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_iso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution_iso, + ) # sample anisotropic resolution only elif (self.max_res_iso is None) & (self.max_res_aniso is not None): - new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), - self.min_res_tens, - tf.where(mask, new_resolution_aniso, self.min_res_tens)) + new_resolution_aniso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_aniso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + tf.where(mask, new_resolution_aniso, self.min_res_tens), + ) # sample either anisotropic or isotropic resolution else: - new_resolution_iso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_iso) - new_resolution_aniso = tf.random.uniform(shape, minval=self.min_res, maxval=self.max_res_aniso) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), - new_resolution_iso, - tf.where(mask, new_resolution_aniso, self.min_res_tens)) - new_resolution = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), - self.min_res_tens, - new_resolution) + new_resolution_iso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_iso + ) + new_resolution_aniso = tf.random.uniform( + shape, minval=self.min_res, maxval=self.max_res_aniso + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_iso)), + new_resolution_iso, + tf.where(mask, new_resolution_aniso, self.min_res_tens), + ) + new_resolution = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_min)), + self.min_res_tens, + new_resolution, + ) if self.return_thickness: - return [new_resolution, tf.random.uniform(tf.shape(self.min_res_tens), self.min_res_tens, new_resolution)] + return [ + new_resolution, + tf.random.uniform( + tf.shape(self.min_res_tens), self.min_res_tens, new_resolution + ), + ] else: return new_resolution def compute_output_shape(self, input_shape): if self.return_thickness: - return [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2 + return ( + [(None, self.n_dims)] * 2 if self.add_batchsize else [self.n_dims] * 2 + ) else: return (None, self.n_dims) if self.add_batchsize else self.n_dims @@ -684,7 +854,9 @@ class GaussianBlur(Layer): def __init__(self, sigma, random_blur_range=None, use_mask=False, **kwargs): self.sigma = utils.reformat_to_list(sigma) - assert np.all(np.array(self.sigma) >= 0), 'sigma should be superior or equal to 0' + assert np.all( + np.array(self.sigma) >= 0 + ), "sigma should be superior or equal to 0" self.use_mask = use_mask self.n_dims = None @@ -707,7 +879,9 @@ def build(self, input_shape): # get shapes if self.use_mask: - assert len(input_shape) == 2, 'please provide a mask as second layer input when use_mask=True' + assert ( + len(input_shape) == 2 + ), "please provide a mask as second layer input when use_mask=True" self.n_dims = len(input_shape[0]) - 2 self.n_channels = input_shape[0][-1] else: @@ -724,7 +898,7 @@ def build(self, input_shape): self.kernels = None # prepare convolution - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) self.built = True super(GaussianBlur, self).build(input_shape) @@ -733,34 +907,76 @@ def call(self, inputs, **kwargs): if self.use_mask: image = inputs[0] - mask = tf.cast(inputs[1], 'bool') + mask = tf.cast(inputs[1], "bool") else: image = inputs mask = None # redefine the kernels at each new step when blur_range is activated if self.blur_range is not None: - self.kernels = l2i_et.gaussian_kernel(self.sigma, blur_range=self.blur_range, separable=self.separable) + self.kernels = l2i_et.gaussian_kernel( + self.sigma, blur_range=self.blur_range, separable=self.separable + ) if self.separable: for k in self.kernels: if k is not None: - image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), k, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + image = tf.concat( + [ + self.convnd( + tf.expand_dims(image[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) if self.use_mask: - maskb = tf.cast(mask, 'float32') - maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), k, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + maskb = tf.cast(mask, "float32") + maskb = tf.concat( + [ + self.convnd( + tf.expand_dims(maskb[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) image = image / (maskb + K.epsilon()) image = tf.where(mask, image, tf.zeros_like(image)) else: if any(self.sigma): - image = tf.concat([self.convnd(tf.expand_dims(image[..., n], -1), self.kernels, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + image = tf.concat( + [ + self.convnd( + tf.expand_dims(image[..., n], -1), + self.kernels, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) if self.use_mask: - maskb = tf.cast(mask, 'float32') - maskb = tf.concat([self.convnd(tf.expand_dims(maskb[..., n], -1), self.kernels, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + maskb = tf.cast(mask, "float32") + maskb = tf.concat( + [ + self.convnd( + tf.expand_dims(maskb[..., n], -1), + self.kernels, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) image = image / (maskb + K.epsilon()) image = tf.where(mask, image, tf.zeros_like(image)) @@ -798,10 +1014,12 @@ def get_config(self): return config def build(self, input_shape): - assert len(input_shape) == 2, 'sigma should be provided as an input tensor for dynamic blurring' + assert ( + len(input_shape) == 2 + ), "sigma should be provided as an input tensor for dynamic blurring" self.n_dims = len(input_shape[0]) - 2 self.n_channels = input_shape[0][-1] - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) self.max_sigma = utils.reformat_to_list(self.max_sigma, length=self.n_dims) self.separable = np.linalg.norm(np.array(self.max_sigma)) > 5 self.built = True @@ -810,7 +1028,9 @@ def build(self, input_shape): def call(self, inputs, **kwargs): image = inputs[0] sigma = inputs[-1] - kernels = l2i_et.gaussian_kernel(sigma, self.max_sigma, self.blur_range, self.separable) + kernels = l2i_et.gaussian_kernel( + sigma, self.max_sigma, self.blur_range, self.separable + ) if self.separable: for kernel in kernels: image = tf.map_fn(self._single_blur, [image, kernel], dtype=tf.float32) @@ -823,11 +1043,21 @@ def _single_blur(self, inputs): split_channels = tf.split(inputs[0], [1] * self.n_channels, axis=-1) blurred_channel = list() for channel in split_channels: - blurred = self.convnd(tf.expand_dims(channel, 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + blurred = self.convnd( + tf.expand_dims(channel, 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ) blurred_channel.append(tf.squeeze(blurred, axis=0)) output = tf.concat(blurred_channel, -1) else: - output = self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], [1] * (self.n_dims + 2), padding='SAME') + output = self.convnd( + tf.expand_dims(inputs[0], 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ) output = tf.squeeze(output, axis=0) return output @@ -872,8 +1102,16 @@ class MimicAcquisition(Layer): Note that the provided res must have higher values than min_low_res. """ - def __init__(self, volume_res, min_subsample_res, resample_shape, build_dist_map=False, - noise_std=0, prob_noise=0.95, **kwargs): + def __init__( + self, + volume_res, + min_subsample_res, + resample_shape, + build_dist_map=False, + noise_std=0, + prob_noise=0.95, + **kwargs + ): # resolutions and dimensions self.volume_res = volume_res @@ -915,11 +1153,17 @@ def build(self, input_shape): self.inshape = input_shape[0][1:] self.n_channels = input_shape[0][-1] self.add_batchsize = False if (input_shape[1][0] is None) else True - down_tensor_shape = np.int32(np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res) + down_tensor_shape = np.int32( + np.array(self.inshape[:-1]) * self.volume_res / self.min_subsample_res + ) # build interpolation meshgrids - self.down_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0) - self.up_grid = tf.expand_dims(tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0) + self.down_grid = tf.expand_dims( + tf.stack(nrn_utils.volshape_to_ndgrid(down_tensor_shape), -1), axis=0 + ) + self.up_grid = tf.expand_dims( + tf.stack(nrn_utils.volshape_to_ndgrid(self.resample_shape), -1), axis=0 + ) self.built = True super(MimicAcquisition, self).build(input_shape) @@ -927,42 +1171,79 @@ def build(self, input_shape): def call(self, inputs, **kwargs): # sort inputs - assert len(inputs) == 2, 'inputs must have two items, the tensor to resample, and the downsampling resolution' + assert ( + len(inputs) == 2 + ), "inputs must have two items, the tensor to resample, and the downsampling resolution" vol = inputs[0] - subsample_res = tf.cast(inputs[1], dtype='float32') + subsample_res = tf.cast(inputs[1], dtype="float32") vol = K.reshape(vol, [-1, *self.inshape]) # necessary for multi_gpu models batchsize = tf.split(tf.shape(vol), [1, -1])[0] - tile_shape = tf.concat([batchsize, tf.ones([1], dtype='int32')], 0) + tile_shape = tf.concat([batchsize, tf.ones([1], dtype="int32")], 0) # get downsampling and upsampling factors if self.add_batchsize: subsample_res = tf.tile(tf.expand_dims(subsample_res, 0), tile_shape) - down_shape = tf.cast(tf.convert_to_tensor(np.array(self.inshape[:-1]) * self.volume_res, dtype='float32') / - subsample_res, dtype='int32') - down_zoom_factor = tf.cast(down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype='float32') - up_zoom_factor = tf.cast(tf.convert_to_tensor(self.resample_shape, dtype='int32') / down_shape, dtype='float32') + down_shape = tf.cast( + tf.convert_to_tensor( + np.array(self.inshape[:-1]) * self.volume_res, dtype="float32" + ) + / subsample_res, + dtype="int32", + ) + down_zoom_factor = tf.cast( + down_shape / tf.convert_to_tensor(self.inshape[:-1]), dtype="float32" + ) + up_zoom_factor = tf.cast( + tf.convert_to_tensor(self.resample_shape, dtype="int32") / down_shape, + dtype="float32", + ) # downsample - down_loc = tf.tile(self.down_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], 0)) - down_loc = tf.cast(down_loc, 'float32') / l2i_et.expand_dims(down_zoom_factor, axis=[1] * self.n_dims) - inshape_tens = tf.tile(tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape) + down_loc = tf.tile( + self.down_grid, + tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype="int32")], 0), + ) + down_loc = tf.cast(down_loc, "float32") / l2i_et.expand_dims( + down_zoom_factor, axis=[1] * self.n_dims + ) + inshape_tens = tf.tile( + tf.expand_dims(tf.convert_to_tensor(self.inshape[:-1]), 0), tile_shape + ) inshape_tens = l2i_et.expand_dims(inshape_tens, axis=[1] * self.n_dims) - down_loc = K.clip(down_loc, 0., tf.cast(inshape_tens, 'float32')) + down_loc = K.clip(down_loc, 0.0, tf.cast(inshape_tens, "float32")) vol = tf.map_fn(self._single_down_interpn, [vol, down_loc], tf.float32) # add noise with predefined probability if self.noise_std > 0: - sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32'), - self.n_channels * tf.ones([1], dtype='int32')], 0) - noise = tf.random.normal(tf.shape(vol), stddev=tf.random.uniform(sample_shape, maxval=self.noise_std)) + sample_shape = tf.concat( + [ + batchsize, + tf.ones([self.n_dims], dtype="int32"), + self.n_channels * tf.ones([1], dtype="int32"), + ], + 0, + ) + noise = tf.random.normal( + tf.shape(vol), + stddev=tf.random.uniform(sample_shape, maxval=self.noise_std), + ) if self.prob_noise == 1: vol += noise else: - vol = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), vol + noise, vol) + vol = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + vol + noise, + vol, + ) # upsample - up_loc = tf.tile(self.up_grid, tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype='int32')], axis=0)) - up_loc = tf.cast(up_loc, 'float32') / l2i_et.expand_dims(up_zoom_factor, axis=[1] * self.n_dims) + up_loc = tf.tile( + self.up_grid, + tf.concat([batchsize, tf.ones([self.n_dims + 1], dtype="int32")], axis=0), + ) + up_loc = tf.cast(up_loc, "float32") / l2i_et.expand_dims( + up_zoom_factor, axis=[1] * self.n_dims + ) vol = tf.map_fn(self._single_up_interpn, [vol, up_loc], tf.float32) # return upsampled volume @@ -981,18 +1262,22 @@ def call(self, inputs, **kwargs): c_dist = ceil - up_loc # keep minimum 1d distances, and compute 3d distance to nearest grid point - dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims(subsample_res, axis=[1] * self.n_dims) - dist = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True)) + dist = tf.math.minimum(f_dist, c_dist) * l2i_et.expand_dims( + subsample_res, axis=[1] * self.n_dims + ) + dist = tf.math.sqrt( + tf.math.reduce_sum(tf.math.square(dist), axis=-1, keepdims=True) + ) return [vol, dist] @staticmethod def _single_down_interpn(inputs): - return nrn_utils.interpn(inputs[0], inputs[1], interp_method='nearest') + return nrn_utils.interpn(inputs[0], inputs[1], interp_method="nearest") @staticmethod def _single_up_interpn(inputs): - return nrn_utils.interpn(inputs[0], inputs[1], interp_method='linear') + return nrn_utils.interpn(inputs[0], inputs[1], interp_method="linear") def compute_output_shape(self, input_shape): output_shape = tuple([None] + self.resample_shape + [input_shape[0][-1]]) @@ -1015,7 +1300,14 @@ class BiasFieldCorruption(Layer): :param prob: probability to apply this bias field corruption. """ - def __init__(self, bias_field_std=.5, bias_scale=.025, same_bias_for_all_channels=False, prob=0.95, **kwargs): + def __init__( + self, + bias_field_std=0.5, + bias_scale=0.025, + same_bias_for_all_channels=False, + prob=0.95, + **kwargs + ): # input shape self.several_inputs = False @@ -1056,7 +1348,9 @@ def build(self, input_shape): # sampling shapes self.std_shape = [1] * (self.n_dims + 1) - self.small_bias_shape = utils.get_resample_shape(self.inshape[0][1:self.n_dims + 1], self.bias_scale, 1) + self.small_bias_shape = utils.get_resample_shape( + self.inshape[0][1 : self.n_dims + 1], self.bias_scale, 1 + ) if not self.same_bias_for_all_channels: self.std_shape[-1] = self.n_channels self.small_bias_shape[-1] = self.n_channels @@ -1073,14 +1367,24 @@ def call(self, inputs, **kwargs): # sampling shapes batchsize = tf.split(tf.shape(inputs[0]), [1, -1])[0] - std_shape = tf.concat([batchsize, tf.convert_to_tensor(self.std_shape, dtype='int32')], 0) - bias_shape = tf.concat([batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype='int32')], axis=0) + std_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.std_shape, dtype="int32")], 0 + ) + bias_shape = tf.concat( + [batchsize, tf.convert_to_tensor(self.small_bias_shape, dtype="int32")], + axis=0, + ) # sample small bias field - bias_field = tf.random.normal(bias_shape, stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std)) + bias_field = tf.random.normal( + bias_shape, + stddev=tf.random.uniform(std_shape, maxval=self.bias_field_std), + ) # resize bias field and take exponential - bias_field = nrn_layers.Resize(size=self.inshape[0][1:self.n_dims + 1], interp_method='linear')(bias_field) + bias_field = nrn_layers.Resize( + size=self.inshape[0][1 : self.n_dims + 1], interp_method="linear" + )(bias_field) bias_field = tf.math.exp(bias_field) # apply bias field with predefined probability @@ -1089,9 +1393,14 @@ def call(self, inputs, **kwargs): else: rand_trans = tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob)) if self.several_inputs: - return [K.switch(rand_trans, tf.math.multiply(bias_field, v), v) for v in inputs] + return [ + K.switch(rand_trans, tf.math.multiply(bias_field, v), v) + for v in inputs + ] else: - return K.switch(rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0]) + return K.switch( + rand_trans, tf.math.multiply(bias_field, inputs[0]), inputs[0] + ) else: return inputs @@ -1125,8 +1434,19 @@ class IntensityAugmentation(Layer): :param prob_gamma: probability to apply gamma augmentation """ - def __init__(self, noise_std=0, clip=0, normalise=True, norm_perc=0, gamma_std=0, contrast_inversion=False, - separate_channels=True, prob_noise=0.95, prob_gamma=1, **kwargs): + def __init__( + self, + noise_std=0, + clip=0, + normalise=True, + norm_perc=0, + gamma_std=0, + contrast_inversion=False, + separate_channels=True, + prob_noise=0.95, + prob_gamma=1, + **kwargs + ): # shape attributes self.n_dims = None @@ -1166,17 +1486,29 @@ def build(self, input_shape): self.n_dims = len(input_shape) - 2 self.n_channels = input_shape[-1] self.flatten_shape = np.prod(np.array(input_shape[1:-1])) - self.flatten_shape = self.flatten_shape * self.n_channels if not self.separate_channels else self.flatten_shape - self.expand_minmax_dim = self.n_dims if self.separate_channels else self.n_dims + 1 - self.one = tf.ones([1], dtype='int32') + self.flatten_shape = ( + self.flatten_shape * self.n_channels + if not self.separate_channels + else self.flatten_shape + ) + self.expand_minmax_dim = ( + self.n_dims if self.separate_channels else self.n_dims + 1 + ) + self.one = tf.ones([1], dtype="int32") if self.clip: self.clip_values = utils.reformat_to_list(self.clip) - self.clip_values = self.clip_values if len(self.clip_values) == 2 else [0, self.clip_values[0]] + self.clip_values = ( + self.clip_values + if len(self.clip_values) == 2 + else [0, self.clip_values[0]] + ) else: self.clip_values = None if self.norm_perc: self.perc = utils.reformat_to_list(self.norm_perc) - self.perc = self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + self.perc = ( + self.perc if len(self.perc) == 2 else [self.perc[0], 1 - self.perc[0]] + ) else: self.perc = None @@ -1188,7 +1520,9 @@ def call(self, inputs, **kwargs): # prepare shape for sampling the noise and gamma std dev (depending on whether we augment channels separately) batchsize = tf.split(tf.shape(inputs), [1, -1])[0] if (self.noise_std > 0) | (self.gamma_std > 0) | self.contrast_inversion: - sample_shape = tf.concat([batchsize, tf.ones([self.n_dims], dtype='int32')], 0) + sample_shape = tf.concat( + [batchsize, tf.ones([self.n_dims], dtype="int32")], 0 + ) if self.separate_channels: sample_shape = tf.concat([sample_shape, self.n_channels * self.one], 0) else: @@ -1202,13 +1536,21 @@ def call(self, inputs, **kwargs): if self.separate_channels: noise = tf.random.normal(tf.shape(inputs), stddev=noise_stddev) else: - noise = tf.random.normal(tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev) - noise = tf.tile(noise, tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels])) + noise = tf.random.normal( + tf.shape(tf.split(inputs, [1, -1], -1)[0]), stddev=noise_stddev + ) + noise = tf.tile( + noise, + tf.convert_to_tensor([1] * (self.n_dims + 1) + [self.n_channels]), + ) if self.prob_noise == 1: inputs = inputs + noise else: - inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), - inputs + noise, inputs) + inputs = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_noise)), + inputs + noise, + inputs, + ) # clip images to given values if self.clip_values is not None: @@ -1219,12 +1561,23 @@ def call(self, inputs, **kwargs): # define robust min and max by sorting values and taking percentile if self.perc is not None: if self.separate_channels: - shape = tf.concat([batchsize, self.flatten_shape * self.one, self.n_channels * self.one], 0) + shape = tf.concat( + [ + batchsize, + self.flatten_shape * self.one, + self.n_channels * self.one, + ], + 0, + ) else: shape = tf.concat([batchsize, self.flatten_shape * self.one], 0) intensities = tf.sort(tf.reshape(inputs, shape), axis=1) m = intensities[:, max(int(self.perc[0] * self.flatten_shape), 0), ...] - M = intensities[:, min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), ...] + M = intensities[ + :, + min(int(self.perc[1] * self.flatten_shape), self.flatten_shape - 1), + ..., + ] # simple min and max else: m = K.min(inputs, axis=list(range(1, self.expand_minmax_dim + 1))) @@ -1241,8 +1594,11 @@ def call(self, inputs, **kwargs): if self.prob_gamma == 1: inputs = tf.math.pow(inputs, tf.math.exp(gamma)) else: - inputs = K.switch(tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), - tf.math.pow(inputs, tf.math.exp(gamma)), inputs) + inputs = K.switch( + tf.squeeze(K.less(tf.random.uniform([1], 0, 1), self.prob_gamma)), + tf.math.pow(inputs, tf.math.exp(gamma)), + inputs, + ) # apply random contrast inversion if self.contrast_inversion: @@ -1250,8 +1606,12 @@ def call(self, inputs, **kwargs): split_channels = tf.split(inputs, [1] * self.n_channels, axis=-1) split_rand_invert = tf.split(rand_invert, [1] * self.n_channels, axis=-1) inverted_channel = list() - for (channel, invert) in zip(split_channels, split_rand_invert): - inverted_channel.append(tf.map_fn(self._single_invert, [channel, invert], dtype=channel.dtype)) + for channel, invert in zip(split_channels, split_rand_invert): + inverted_channel.append( + tf.map_fn( + self._single_invert, [channel, invert], dtype=channel.dtype + ) + ) inputs = tf.concat(inverted_channel, -1) return inputs @@ -1279,13 +1639,15 @@ class DiceLoss(Layer): probabilities sum to 1 at each voxel location). Default is True. """ - def __init__(self, - class_weights=None, - boundary_weights=0, - boundary_dist=3, - skip_background=True, - enable_checks=True, - **kwargs): + def __init__( + self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs + ): self.class_weights = class_weights self.dynamic_weighting = False @@ -1310,13 +1672,17 @@ def get_config(self): def build(self, input_shape): # get shape - assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' - assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + assert ( + len(input_shape) == 2 + ), "DiceLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." inshape = input_shape[0][1:] n_dims = len(inshape[:-1]) n_labels = inshape[-1] self.spatial_axes = list(range(1, n_dims + 1)) - self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.avg_pooling_layer = getattr(keras.layers, "AvgPool%dD" % n_dims) self.skip_background = False if n_labels == 1 else self.skip_background # build tensor with class weights @@ -1324,8 +1690,10 @@ def build(self, input_shape): if self.class_weights == -1: self.dynamic_weighting = True else: - class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) - class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) self.built = True @@ -1336,9 +1704,27 @@ def call(self, inputs, **kwargs): # make sure tensors are probabilistic gt = inputs[0] pred = inputs[1] - if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps - gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) - pred = K.clip(pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = K.clip( + gt + / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) + pred = K.clip( + pred + / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) # compute dice loss for each label top = 2 * gt * pred @@ -1346,11 +1732,18 @@ def call(self, inputs, **kwargs): # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) if self.boundary_weights: - avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) - boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + avg = self.avg_pooling_layer( + pool_size=2 * self.boundary_dist + 1, strides=1, padding="same" + )(gt) + boundaries = tf.cast(avg > 0.0, "float32") * tf.cast( + avg < (1 / len(self.spatial_axes) - 1e-4), "float32" + ) if self.skip_background: boundaries_channels = tf.unstack(boundaries, axis=-1) - boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundaries = tf.stack( + [tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], + axis=-1, + ) boundary_weights_tensor = 1 + self.boundary_weights * boundaries top *= boundary_weights_tensor bottom *= boundary_weights_tensor @@ -1360,17 +1753,25 @@ def call(self, inputs, **kwargs): # compute loss top = tf.math.reduce_sum(top, self.spatial_axes) bottom = tf.math.reduce_sum(bottom, self.spatial_axes) - dice = (top + tf.keras.backend.epsilon()) / (bottom + tf.keras.backend.epsilon()) + dice = (top + tf.keras.backend.epsilon()) / ( + bottom + tf.keras.backend.epsilon() + ) loss = 1 - dice # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). - if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt - if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume - self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes) + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt + if ( + boundary_weights_tensor is not None + ): # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum( + gt * boundary_weights_tensor, self.spatial_axes + ) else: self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) if self.class_weights_tens is not None: - self. class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) + self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) loss = tf.reduce_sum(loss * self.class_weights_tens, -1) return tf.math.reduce_mean(loss) @@ -1399,8 +1800,12 @@ def get_config(self): return config def build(self, input_shape): - assert len(input_shape) == 2, 'DiceLoss expects 2 inputs to compute the Dice loss.' - assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + assert ( + len(input_shape) == 2 + ), "DiceLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." self.n_labels = input_shape[0][-1] self.built = True super(WeightedL2Loss, self).build(input_shape) @@ -1409,7 +1814,9 @@ def call(self, inputs, **kwargs): gt = inputs[0] pred = inputs[1] weights = tf.expand_dims(1 - gt[..., 0] + 1e-8, -1) - return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / (K.sum(weights) * self.n_labels) + return K.sum(weights * K.square(pred - self.target_value * (2 * gt - 1))) / ( + K.sum(weights) * self.n_labels + ) def compute_output_shape(self, input_shape): return [[]] @@ -1433,13 +1840,15 @@ class CrossEntropyLoss(Layer): probabilities sum to 1 at each voxel location). Default is True. """ - def __init__(self, - class_weights=None, - boundary_weights=0, - boundary_dist=3, - skip_background=True, - enable_checks=True, - **kwargs): + def __init__( + self, + class_weights=None, + boundary_weights=0, + boundary_dist=3, + skip_background=True, + enable_checks=True, + **kwargs + ): self.class_weights = class_weights self.dynamic_weighting = False @@ -1464,13 +1873,17 @@ def get_config(self): def build(self, input_shape): # get shape - assert len(input_shape) == 2, 'CrossEntropy expects 2 inputs to compute the Dice loss.' - assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + assert ( + len(input_shape) == 2 + ), "CrossEntropy expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." inshape = input_shape[0][1:] n_dims = len(inshape[:-1]) n_labels = inshape[-1] self.spatial_axes = list(range(1, n_dims + 1)) - self.avg_pooling_layer = getattr(keras.layers, 'AvgPool%dD' % n_dims) + self.avg_pooling_layer = getattr(keras.layers, "AvgPool%dD" % n_dims) self.skip_background = False if n_labels == 1 else self.skip_background # build tensor with class weights @@ -1478,9 +1891,13 @@ def build(self, input_shape): if self.class_weights == -1: self.dynamic_weighting = True else: - class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) - class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') - self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, [0] * (1 + n_dims)) + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") + self.class_weights_tens = l2i_et.expand_dims( + class_weights_tens, [0] * (1 + n_dims) + ) self.built = True super(CrossEntropyLoss, self).build(input_shape) @@ -1490,30 +1907,58 @@ def call(self, inputs, **kwargs): # make sure tensors are probabilistic gt = inputs[0] pred = inputs[1] - if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps - gt = K.clip(gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()), 0, 1) - pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) - pred = K.clip(pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon()) # to avoid log(0) + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = K.clip( + gt + / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ), + 0, + 1, + ) + pred = pred / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ) + pred = K.clip( + pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon() + ) # to avoid log(0) # compare prediction/target, ce has the same shape has the input tensors ce = -gt * tf.math.log(pred) # apply boundary weighting (ie voxels close to region boundaries will be counted several times to compute Dice) if self.boundary_weights: - avg = self.avg_pooling_layer(pool_size=2 * self.boundary_dist + 1, strides=1, padding='same')(gt) - boundaries = tf.cast(avg > 0., 'float32') * tf.cast(avg < (1 / len(self.spatial_axes) - 1e-4), 'float32') + avg = self.avg_pooling_layer( + pool_size=2 * self.boundary_dist + 1, strides=1, padding="same" + )(gt) + boundaries = tf.cast(avg > 0.0, "float32") * tf.cast( + avg < (1 / len(self.spatial_axes) - 1e-4), "float32" + ) if self.skip_background: boundaries_channels = tf.unstack(boundaries, axis=-1) - boundaries = tf.stack([tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], axis=-1) + boundaries = tf.stack( + [tf.zeros_like(boundaries_channels[0])] + boundaries_channels[1:], + axis=-1, + ) boundary_weights_tensor = 1 + self.boundary_weights * boundaries ce *= boundary_weights_tensor else: boundary_weights_tensor = None # apply class weighting across labels. By the end of this, ce still has the same shape has the input tensors. - if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt - if boundary_weights_tensor is not None: # we account for the boundary weighting to compute volume - self.class_weights_tens = 1 / tf.reduce_sum(gt * boundary_weights_tensor, self.spatial_axes, True) + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt + if ( + boundary_weights_tensor is not None + ): # we account for the boundary weighting to compute volume + self.class_weights_tens = 1 / tf.reduce_sum( + gt * boundary_weights_tensor, self.spatial_axes, True + ) else: self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) if self.class_weights_tens is not None: @@ -1560,8 +2005,12 @@ def get_config(self): def build(self, input_shape): # get shape - assert len(input_shape) == 2, 'MomentLoss expects 2 inputs to compute the Dice loss.' - assert input_shape[0] == input_shape[1], 'the two inputs must have the same shape.' + assert ( + len(input_shape) == 2 + ), "MomentLoss expects 2 inputs to compute the Dice loss." + assert ( + input_shape[0] == input_shape[1] + ), "the two inputs must have the same shape." inshape = input_shape[0][1:] n_dims = len(inshape[:-1]) n_labels = inshape[-1] @@ -1569,15 +2018,20 @@ def build(self, input_shape): # build coordinate meshgrid of size (1, dim1, dim2, ..., dimN, ndim, nchan) self.coordinates = tf.stack(nrn_utils.volshape_to_ndgrid(inshape[:-1]), -1) - self.coordinates = tf.cast(l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), 'float32') + self.coordinates = tf.cast( + l2i_et.expand_dims(tf.stack([self.coordinates] * n_labels, -1), 0), + "float32", + ) # build tensor with class weights if self.class_weights is not None: if self.class_weights == -1: self.dynamic_weighting = True else: - class_weights_tens = utils.reformat_to_list(self.class_weights, n_labels) - class_weights_tens = tf.convert_to_tensor(class_weights_tens, 'float32') + class_weights_tens = utils.reformat_to_list( + self.class_weights, n_labels + ) + class_weights_tens = tf.convert_to_tensor(class_weights_tens, "float32") self.class_weights_tens = l2i_et.expand_dims(class_weights_tens, 0) self.built = True @@ -1588,17 +2042,31 @@ def call(self, inputs, **kwargs): # make sure tensors are probabilistic gt = inputs[0] # (B, dim1, dim2, ..., dimN, nchan) pred = inputs[1] - if self.enable_checks: # disabling is useful to, e.g., use incomplete label maps - gt = gt / (tf.math.reduce_sum(gt, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) - pred = pred / (tf.math.reduce_sum(pred, axis=-1, keepdims=True) + tf.keras.backend.epsilon()) + if ( + self.enable_checks + ): # disabling is useful to, e.g., use incomplete label maps + gt = gt / ( + tf.math.reduce_sum(gt, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ) + pred = pred / ( + tf.math.reduce_sum(pred, axis=-1, keepdims=True) + + tf.keras.backend.epsilon() + ) # compute loss gt_mean_coordinates = self._mean_coordinates(gt) # (B, ndim, nchan) pred_mean_coordinates = self._mean_coordinates(pred) - loss = tf.math.sqrt(tf.reduce_sum(tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1)) # (B, nchan) + loss = tf.math.sqrt( + tf.reduce_sum( + tf.square(pred_mean_coordinates - gt_mean_coordinates), axis=1 + ) + ) # (B, nchan) # apply class weighting across labels. In this case loss will have shape (batch), otherwise (batch, n_labels). - if self.dynamic_weighting: # the weight of a class is the inverse of its volume in the gt + if ( + self.dynamic_weighting + ): # the weight of a class is the inverse of its volume in the gt self.class_weights_tens = 1 / tf.reduce_sum(gt, self.spatial_axes) if self.class_weights_tens is not None: self.class_weights_tens /= tf.reduce_sum(self.class_weights_tens, -1) @@ -1607,9 +2075,15 @@ def call(self, inputs, **kwargs): return tf.math.reduce_mean(loss) def _mean_coordinates(self, tensor): - tensor = l2i_et.expand_dims(tensor, axis=-2) # (B, dim1, dim2, ..., dimN, 1, nchan) - numerator = tf.reduce_sum(tensor * self.coordinates, axis=self.spatial_axes) # (B, ndim, nchan) - denominator = tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + tensor = l2i_et.expand_dims( + tensor, axis=-2 + ) # (B, dim1, dim2, ..., dimN, 1, nchan) + numerator = tf.reduce_sum( + tensor * self.coordinates, axis=self.spatial_axes + ) # (B, ndim, nchan) + denominator = ( + tf.reduce_sum(tensor, axis=self.spatial_axes) + tf.keras.backend.epsilon() + ) return numerator / denominator def compute_output_shape(self, input_shape): @@ -1633,7 +2107,9 @@ class ResetValuesToZero(Layer): """ def __init__(self, values, **kwargs): - assert values is not None, 'please provide correct list of values, received None' + assert ( + values is not None + ), "please provide correct list of values, received None" self.values = utils.reformat_to_list(values) self.values_tens = None self.n_values = len(values) @@ -1652,7 +2128,9 @@ def build(self, input_shape): def call(self, inputs, **kwargs): values = tf.cast(self.values_tens, dtype=inputs.dtype) for i in range(self.n_values): - inputs = tf.where(tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs) + inputs = tf.where( + tf.equal(inputs, values[i]), tf.zeros_like(inputs), inputs + ) return inputs @@ -1681,12 +2159,15 @@ def get_config(self): return config def build(self, input_shape): - self.lut = tf.convert_to_tensor(utils.get_mapping_lut(self.source_values, dest=self.dest_values), dtype='int32') + self.lut = tf.convert_to_tensor( + utils.get_mapping_lut(self.source_values, dest=self.dest_values), + dtype="int32", + ) self.built = True super(ConvertLabels, self).build(input_shape) def call(self, inputs, **kwargs): - return tf.gather(self.lut, tf.cast(inputs, dtype='int32')) + return tf.gather(self.lut, tf.cast(inputs, dtype="int32")) class PadAroundCentre(Layer): @@ -1725,19 +2206,32 @@ def build(self, input_shape): shape[-1] = 0 if self.pad_margin is not None: - assert self.pad_shape is None, 'please do not provide a padding shape and margin at the same time.' + assert ( + self.pad_shape is None + ), "please do not provide a padding shape and margin at the same time." # reformat padding margins - pad = np.transpose(np.array([[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] * 2)) - self.pad_margin_tens = tf.convert_to_tensor(pad, dtype='int32') + pad = np.transpose( + np.array( + [[0] + utils.reformat_to_list(self.pad_margin, self.n_dims) + [0]] + * 2 + ) + ) + self.pad_margin_tens = tf.convert_to_tensor(pad, dtype="int32") elif self.pad_shape is not None: - assert self.pad_margin is None, 'please do not provide a padding shape and margin at the same time.' + assert ( + self.pad_margin is None + ), "please do not provide a padding shape and margin at the same time." # pad shape - tensor_shape = tf.cast(tf.convert_to_tensor(shape), 'int32') - self.pad_shape_tens = np.array([0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0]) - self.pad_shape_tens = tf.convert_to_tensor(self.pad_shape_tens, dtype='int32') + tensor_shape = tf.cast(tf.convert_to_tensor(shape), "int32") + self.pad_shape_tens = np.array( + [0] + utils.reformat_to_list(self.pad_shape, length=self.n_dims) + [0] + ) + self.pad_shape_tens = tf.convert_to_tensor( + self.pad_shape_tens, dtype="int32" + ) self.pad_shape_tens = tf.math.maximum(tensor_shape, self.pad_shape_tens) # padding margin @@ -1746,13 +2240,17 @@ def build(self, input_shape): self.pad_margin_tens = tf.stack([min_margins, max_margins], axis=-1) else: - raise Exception('please either provide a padding shape or a padding margin.') + raise Exception( + "please either provide a padding shape or a padding margin." + ) self.built = True super(PadAroundCentre, self).build(input_shape) def call(self, inputs, **kwargs): - return tf.pad(inputs, self.pad_margin_tens, mode='CONSTANT', constant_values=self.value) + return tf.pad( + inputs, self.pad_margin_tens, mode="CONSTANT", constant_values=self.value + ) class MaskEdges(Layer): @@ -1796,8 +2294,10 @@ class MaskEdges(Layer): """ def __init__(self, axes, boundaries, prob_mask=1, **kwargs): - self.axes = utils.reformat_to_list(axes, dtype='int') - self.boundaries = utils.reformat_to_n_channels_array(boundaries, n_dims=4, n_channels=len(self.axes)) + self.axes = utils.reformat_to_list(axes, dtype="int") + self.boundaries = utils.reformat_to_n_channels_array( + boundaries, n_dims=4, n_channels=len(self.axes) + ) self.prob_mask = prob_mask self.inputshape = None super(MaskEdges, self).__init__(**kwargs) @@ -1822,26 +2322,42 @@ def call(self, inputs, **kwargs): # select restricting indices axis_boundaries = self.boundaries[i, :] - idx1 = tf.math.round(tf.random.uniform([1], - minval=axis_boundaries[0] * self.inputshape[axis], - maxval=axis_boundaries[1] * self.inputshape[axis])) - idx2 = tf.math.round(tf.random.uniform([1], - minval=axis_boundaries[2] * self.inputshape[axis], - maxval=axis_boundaries[3] * self.inputshape[axis] - 1) - idx1) + idx1 = tf.math.round( + tf.random.uniform( + [1], + minval=axis_boundaries[0] * self.inputshape[axis], + maxval=axis_boundaries[1] * self.inputshape[axis], + ) + ) + idx2 = tf.math.round( + tf.random.uniform( + [1], + minval=axis_boundaries[2] * self.inputshape[axis], + maxval=axis_boundaries[3] * self.inputshape[axis] - 1, + ) + - idx1 + ) idx3 = self.inputshape[axis] - idx1 - idx2 - split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype='int32') + split_idx = tf.cast(tf.concat([idx1, idx2, idx3], axis=0), dtype="int32") # update mask split_list = tf.split(inputs, split_idx, axis=axis) - tmp_mask = tf.concat([tf.zeros_like(split_list[0]), - tf.ones_like(split_list[1]), - tf.zeros_like(split_list[2])], axis=axis) + tmp_mask = tf.concat( + [ + tf.zeros_like(split_list[0]), + tf.ones_like(split_list[1]), + tf.zeros_like(split_list[2]), + ], + axis=axis, + ) mask = mask * tmp_mask # mask second_channel - tensor = K.switch(tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), - inputs * mask, - inputs) + tensor = K.switch( + tf.squeeze(K.greater(tf.random.uniform([1], 0, 1), 1 - self.prob_mask)), + inputs * mask, + inputs, + ) return [tensor, mask] @@ -1851,11 +2367,15 @@ def compute_output_shape(self, input_shape): class ImageGradients(Layer): - def __init__(self, gradient_type='sobel', return_magnitude=False, **kwargs): + def __init__(self, gradient_type="sobel", return_magnitude=False, **kwargs): self.gradient_type = gradient_type - assert (self.gradient_type == 'sobel') | (self.gradient_type == '1-step_diff'), \ - 'gradient_type should be either sobel or 1-step_diff, had %s' % self.gradient_type + assert (self.gradient_type == "sobel") | ( + self.gradient_type == "1-step_diff" + ), ( + "gradient_type should be either sobel or 1-step_diff, had %s" + % self.gradient_type + ) # shape self.n_dims = 0 @@ -1885,10 +2405,10 @@ def build(self, input_shape): self.n_channels = input_shape[-1] # prepare kernel if sobel gradients - if self.gradient_type == 'sobel': + if self.gradient_type == "sobel": self.kernels = l2i_et.sobel_kernels(self.n_dims) self.stride = [1] * (self.n_dims + 2) - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) else: self.kernels = self.convnd = self.stride = None @@ -1902,14 +2422,24 @@ def call(self, inputs, **kwargs): gradients = list() # sobel method - if self.gradient_type == 'sobel': + if self.gradient_type == "sobel": # get sobel gradients in each direction for n in range(self.n_dims): gradient = image # apply 1D kernel in each direction (sobel kernels are separable), instead of applying a nD kernel for k in self.kernels[n]: - gradient = tf.concat([self.convnd(tf.expand_dims(gradient[..., n], -1), k, self.stride, 'SAME') - for n in range(self.n_channels)], -1) + gradient = tf.concat( + [ + self.convnd( + tf.expand_dims(gradient[..., n], -1), + k, + self.stride, + "SAME", + ) + for n in range(self.n_channels) + ], + -1, + ) gradients.append(gradient) # 1-step method, only supports 2 and 3D @@ -1926,18 +2456,28 @@ def call(self, inputs, **kwargs): gradients.append(image[:, :, :, 1:, :] - image[:, :, :, :-1, :]) # dz else: - raise Exception('ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD' % self.n_dims) + raise Exception( + "ImageGradients only support 2D or 3D tensors for 1-step diff, had: %dD" + % self.n_dims + ) # pad with zeros to return tensors of the same shape as input for i in range(self.n_dims): tmp_shape = list(self.shape) tmp_shape[i] = 1 - zeros = tf.zeros(tf.concat([batchsize, tf.convert_to_tensor(tmp_shape, dtype='int32')], 0), image.dtype) + zeros = tf.zeros( + tf.concat( + [batchsize, tf.convert_to_tensor(tmp_shape, dtype="int32")], 0 + ), + image.dtype, + ) gradients[i] = tf.concat([gradients[i], zeros], axis=i + 1) # compute total gradient magnitude if necessary, or concatenate different gradients along the channel axis if self.return_magnitude: - gradients = tf.sqrt(tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1)) + gradients = tf.sqrt( + tf.reduce_sum(tf.square(tf.stack(gradients, axis=-1)), axis=-1) + ) else: gradients = tf.concat(gradients, axis=-1) @@ -1965,12 +2505,22 @@ class RandomDilationErosion(Layer): choice to either return the eroded label map or the mask (return_mask=True) """ - def __init__(self, min_factor, max_factor, max_factor_dilate=None, prob=1, operation='random', return_mask=False, - **kwargs): + def __init__( + self, + min_factor, + max_factor, + max_factor_dilate=None, + prob=1, + operation="random", + return_mask=False, + **kwargs + ): self.min_factor = min_factor self.max_factor = max_factor - self.max_factor_dilate = max_factor_dilate if max_factor_dilate is not None else self.max_factor + self.max_factor_dilate = ( + max_factor_dilate if max_factor_dilate is not None else self.max_factor + ) self.prob = prob self.operation = operation self.return_mask = return_mask @@ -1998,7 +2548,7 @@ def build(self, input_shape): self.n_channels = self.inshape[-1] # prepare convolution - self.convnd = getattr(tf.nn, 'conv%dd' % self.n_dims) + self.convnd = getattr(tf.nn, "conv%dd" % self.n_dims) self.built = True super(RandomDilationErosion, self).build(input_shape) @@ -2007,30 +2557,42 @@ def call(self, inputs, **kwargs): # sample probability of applying operation. If random negative is erosion and positive is dilation batchsize = tf.split(tf.shape(inputs), [1, -1])[0] - shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype='int32')], axis=0) - if self.operation == 'dilation': + shape = tf.concat([batchsize, tf.convert_to_tensor([1], dtype="int32")], axis=0) + if self.operation == "dilation": prob = tf.random.uniform(shape, 0, 1) - elif self.operation == 'erosion': + elif self.operation == "erosion": prob = tf.random.uniform(shape, -1, 0) - elif self.operation == 'random': + elif self.operation == "random": prob = tf.random.uniform(shape, -1, 1) else: - raise ValueError("operation should either be 'dilation' 'erosion' or 'random', had %s" % self.operation) + raise ValueError( + "operation should either be 'dilation' 'erosion' or 'random', had %s" + % self.operation + ) # build kernel if self.min_factor == self.max_factor: - dist_threshold = self.min_factor * tf.ones(shape, dtype='int32') + dist_threshold = self.min_factor * tf.ones(shape, dtype="int32") else: - if (self.max_factor == self.max_factor_dilate) | (self.operation != 'random'): - dist_threshold = tf.random.uniform(shape, minval=self.min_factor, maxval=self.max_factor, dtype='int32') + if (self.max_factor == self.max_factor_dilate) | ( + self.operation != "random" + ): + dist_threshold = tf.random.uniform( + shape, minval=self.min_factor, maxval=self.max_factor, dtype="int32" + ) else: - dist_threshold = tf.cast(tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), dtype='int32') - kernel = l2i_et.unit_kernel(dist_threshold, self.n_dims, max_dist_threshold=self.max_factor) + dist_threshold = tf.cast( + tf.map_fn(self._sample_factor, [prob], dtype=tf.float32), + dtype="int32", + ) + kernel = l2i_et.unit_kernel( + dist_threshold, self.n_dims, max_dist_threshold=self.max_factor + ) # convolve input mask with kernel according to given probability - mask = tf.cast(tf.cast(inputs, dtype='bool'), dtype='float32') + mask = tf.cast(tf.cast(inputs, dtype="bool"), dtype="float32") mask = tf.map_fn(self._single_blur, [mask, kernel, prob], dtype=tf.float32) - mask = tf.cast(mask, 'bool') + mask = tf.cast(mask, "bool") if self.return_mask: return mask @@ -2038,22 +2600,61 @@ def call(self, inputs, **kwargs): return inputs * tf.cast(mask, dtype=inputs.dtype) def _sample_factor(self, inputs): - return tf.cast(K.switch(K.less(tf.squeeze(inputs[0]), 0), - tf.random.uniform((1,), self.min_factor, self.max_factor, dtype='int32'), - tf.random.uniform((1,), self.min_factor, self.max_factor_dilate, dtype='int32')), - dtype='float32') + return tf.cast( + K.switch( + K.less(tf.squeeze(inputs[0]), 0), + tf.random.uniform( + (1,), self.min_factor, self.max_factor, dtype="int32" + ), + tf.random.uniform( + (1,), self.min_factor, self.max_factor_dilate, dtype="int32" + ), + ), + dtype="float32", + ) def _single_blur(self, inputs): # dilate... - new_mask = K.switch(K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(inputs[0], 0), inputs[1], - [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), - inputs[0]) + new_mask = K.switch( + K.greater(tf.squeeze(inputs[2]), 1 - self.prob + 0.001), + tf.cast( + tf.greater( + tf.squeeze( + self.convnd( + tf.expand_dims(inputs[0], 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ), + axis=0, + ), + 0.01, + ), + dtype="float32", + ), + inputs[0], + ) # ...or erode - new_mask = K.switch(K.less(tf.squeeze(inputs[2]), - (1 - self.prob + 0.001)), - 1 - tf.cast(tf.greater(tf.squeeze(self.convnd(tf.expand_dims(1 - new_mask, 0), inputs[1], - [1] * (self.n_dims + 2), padding='SAME'), axis=0), 0.01), dtype='float32'), - new_mask) + new_mask = K.switch( + K.less(tf.squeeze(inputs[2]), -(1 - self.prob + 0.001)), + 1 + - tf.cast( + tf.greater( + tf.squeeze( + self.convnd( + tf.expand_dims(1 - new_mask, 0), + inputs[1], + [1] * (self.n_dims + 2), + padding="SAME", + ), + axis=0, + ), + 0.01, + ), + dtype="float32", + ), + new_mask, + ) return new_mask def compute_output_shape(self, input_shape): diff --git a/nobrainer/ext/lab2im/utils.py b/nobrainer/ext/lab2im/utils.py index 1f3b8888..66e5c03f 100644 --- a/nobrainer/ext/lab2im/utils.py +++ b/nobrainer/ext/lab2im/utils.py @@ -55,20 +55,19 @@ License. """ - -import os +from datetime import timedelta import glob import math -import time +import os import pickle -import numpy as np -import nibabel as nib -import tensorflow as tf -import keras.layers as KL +import time + import keras.backend as K -from datetime import timedelta +import keras.layers as KL +import nibabel as nib +import numpy as np from scipy.ndimage.morphology import distance_transform_edt - +import tensorflow as tf # ---------------------------------------------- loading/saving functions ---------------------------------------------- @@ -86,9 +85,11 @@ def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=Non The returned affine matrix is also given in this new space. Must be a numpy array of dimension 4x4. :return: the volume, with corresponding affine matrix and header if im_only is False. """ - assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume + assert path_volume.endswith((".nii", ".nii.gz", ".mgz", ".npz")), ( + "Unknown data file: %s" % path_volume + ) - if path_volume.endswith(('.nii', '.nii.gz', '.mgz')): + if path_volume.endswith((".nii", ".nii.gz", ".mgz")): x = nib.load(path_volume) if squeeze: volume = np.squeeze(x.get_fdata()) @@ -97,21 +98,26 @@ def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=Non aff = x.affine header = x.header else: # npz - volume = np.load(path_volume)['vol_data'] + volume = np.load(path_volume)["vol_data"] if squeeze: volume = np.squeeze(volume) aff = np.eye(4) header = nib.Nifti1Header() if dtype is not None: - if 'int' in dtype: + if "int" in dtype: volume = np.round(volume) volume = volume.astype(dtype=dtype) # align image to reference affine matrix if aff_ref is not None: - from nobrainer.ext.lab2im import edit_volumes # the import is done here to avoid import loops + from nobrainer.ext.lab2im import ( # the import is done here to avoid import loops + edit_volumes, + ) + n_dims, _ = get_dims(list(volume.shape), max_channels=10) - volume, aff = edit_volumes.align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims) + volume, aff = edit_volumes.align_volume_to_ref( + volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims + ) if im_only: return volume @@ -134,18 +140,20 @@ def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): """ mkdir(os.path.dirname(path)) - if '.npz' in path: + if ".npz" in path: np.savez_compressed(path, vol_data=volume) else: if header is None: header = nib.Nifti1Header() if isinstance(aff, str): - if aff == 'FS': - aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) + if aff == "FS": + aff = np.array( + [[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]] + ) elif aff is None: aff = np.eye(4) if dtype is not None: - if 'int' in dtype: + if "int" in dtype: volume = np.round(volume) volume = volume.astype(dtype=dtype) nifty = nib.Nifti1Image(volume, aff, header) @@ -180,16 +188,19 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels im_shape = im_shape[:n_dims] # get labels res - if '.nii' in path_volume: - data_res = np.array(header['pixdim'][1:n_dims + 1]) - elif '.mgz' in path_volume: - data_res = np.array(header['delta']) # mgz image + if ".nii" in path_volume: + data_res = np.array(header["pixdim"][1 : n_dims + 1]) + elif ".mgz" in path_volume: + data_res = np.array(header["delta"]) # mgz image else: data_res = np.array([1.0] * n_dims) # align to given affine matrix if aff_ref is not None: - from nobrainer.ext.lab2im import edit_volumes # the import is done here to avoid import loops + from nobrainer.ext.lab2im import ( # the import is done here to avoid import loops + edit_volumes, + ) + ras_axes = edit_volumes.get_ras_axes(aff, n_dims=n_dims) ras_axes_ref = edit_volumes.get_ras_axes(aff_ref, n_dims=n_dims) im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims) @@ -206,7 +217,9 @@ def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels return im_shape, aff, n_dims, n_channels, header, data_res -def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_sort=False): +def get_list_labels( + label_list=None, labels_dir=None, save_label_list=None, FS_sort=False +): """This function reads or computes a list of all label values used in a set of label maps. It can also sort all labels according to FreeSurfer lut. :param label_list: (optional) already computed label_list. Can be a sequence, a 1d numpy array, or the path to @@ -224,32 +237,104 @@ def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_s # load label list if previously computed if label_list is not None: - label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + label_list = np.array( + reformat_to_list(label_list, load_as_numpy=True, dtype="int") + ) # compute label list from all label files elif labels_dir is not None: - print('Compiling list of unique labels') + print("Compiling list of unique labels") # go through all labels files and compute unique list of labels labels_paths = list_images_in_folder(labels_dir) label_list = np.empty(0) - loop_info = LoopInfo(len(labels_paths), 10, 'processing', print_time=True) + loop_info = LoopInfo(len(labels_paths), 10, "processing", print_time=True) for lab_idx, path in enumerate(labels_paths): loop_info.update(lab_idx) - y = load_volume(path, dtype='int32') + y = load_volume(path, dtype="int32") y_unique = np.unique(y) - label_list = np.unique(np.concatenate((label_list, y_unique))).astype('int') + label_list = np.unique(np.concatenate((label_list, y_unique))).astype("int") else: - raise Exception('either label_list, path_label_list or labels_dir should be provided') + raise Exception( + "either label_list, path_label_list or labels_dir should be provided" + ) # sort labels in neutral/left/right according to FS labels n_neutral_labels = 0 if FS_sort: - neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108, - 109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, - 251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, - 502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530, - 531, 532, 533, 534, 535, 536, 537] + neutral_FS_labels = [ + 0, + 14, + 15, + 16, + 21, + 22, + 23, + 24, + 72, + 77, + 80, + 85, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 165, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 251, + 252, + 253, + 254, + 255, + 258, + 259, + 260, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 502, + 506, + 507, + 508, + 509, + 511, + 512, + 514, + 515, + 516, + 517, + 530, + 531, + 532, + 533, + 534, + 535, + 536, + 537, + ] neutral = list() left = list() right = list() @@ -257,19 +342,36 @@ def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_s if la in neutral_FS_labels: if la not in neutral: neutral.append(la) - elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \ - (la == 865) | (20100 < la < 20110): + elif ( + (0 < la < 14) + | (16 < la < 21) + | (24 < la < 40) + | (135 < la < 139) + | (1000 <= la <= 1035) + | (la == 865) + | (20100 < la < 20110) + ): if la not in left: left.append(la) - elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \ - (la == 866): + elif ( + (39 < la < 72) + | (162 < la < 165) + | (2000 <= la <= 2035) + | (20000 < la < 20010) + | (la == 139) + | (la == 866) + ): if la not in right: right.append(la) else: - raise Exception('label {} not in our current FS classification, ' - 'please update get_list_labels in utils.py'.format(la)) + raise Exception( + "label {} not in our current FS classification, " + "please update get_list_labels in utils.py".format(la) + ) label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)]) - if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)): + if ((len(left) > 0) & (len(right) > 0)) | ( + (len(left) == 0) & (len(right) == 0) + ): n_neutral_labels = len(neutral) else: n_neutral_labels = len(label_list) @@ -288,29 +390,29 @@ def load_array_if_path(var, load_as_numpy=True): """If var is a string and load_as_numpy is True, this function loads the array writen at the path indicated by var. Otherwise it simply returns var as it is.""" if (isinstance(var, str)) & load_as_numpy: - assert os.path.isfile(var), 'No such path: %s' % var + assert os.path.isfile(var), "No such path: %s" % var var = np.load(var) return var def write_pickle(filepath, obj): - """ write a python object with a pickle at a given path""" - with open(filepath, 'wb') as file: + """write a python object with a pickle at a given path""" + with open(filepath, "wb") as file: pickler = pickle.Pickler(file) pickler.dump(obj) def read_pickle(filepath): - """ read a python object with a pickle""" - with open(filepath, 'rb') as file: + """read a python object with a pickle""" + with open(filepath, "rb") as file: unpickler = pickle.Unpickler(file) return unpickler.load() -def write_model_summary(model, filepath='./model_summary.txt', line_length=150): +def write_model_summary(model, filepath="./model_summary.txt", line_length=150): """Write the summary of a keras model at a given path, with a given length for each line""" - with open(filepath, 'w') as fh: - model.summary(print_fn=lambda x: fh.write(x + '\n'), line_length=line_length) + with open(filepath, "w") as fh: + model.summary(print_fn=lambda x: fh.write(x + "\n"), line_length=line_length) # ----------------------------------------------- reformatting functions ----------------------------------------------- @@ -350,23 +452,29 @@ def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): if len(var) == 1: var = var * length elif len(var) != length: - raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, ' - 'had {1}'.format(length, var)) + raise ValueError( + "if var is a list/tuple/numpy array, it should be of length 1 or {0}, " + "had {1}".format(length, var) + ) else: - raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array') + raise TypeError( + "var should be an int, float, tuple, list, numpy array, or path to numpy array" + ) # convert items type if dtype is not None: - if dtype == 'int': + if dtype == "int": var = [int(v) for v in var] - elif dtype == 'float': + elif dtype == "float": var = [float(v) for v in var] - elif dtype == 'bool': + elif dtype == "bool": var = [bool(v) for v in var] - elif dtype == 'str': + elif dtype == "str": var = [str(v) for v in var] else: - raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype)) + raise ValueError( + "dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype) + ) return var @@ -391,9 +499,13 @@ def reformat_to_n_channels_array(var, n_dims=3, n_channels=1): if np.squeeze(var).shape == (n_dims,): var = np.tile(var.reshape((1, n_dims)), (n_channels, 1)) elif var.shape != (n_channels, n_dims): - raise ValueError('if array, var should be {0} or {1}'.format((1, n_dims), (n_channels, n_dims))) + raise ValueError( + "if array, var should be {0} or {1}".format( + (1, n_dims), (n_channels, n_dims) + ) + ) else: - raise TypeError('var should be int, float, list, tuple or ndarray') + raise TypeError("var should be int, float, list, tuple or ndarray") return np.round(var, 3) @@ -403,24 +515,32 @@ def reformat_to_n_channels_array(var, n_dims=3, n_channels=1): def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True): """List all files with extension nii, nii.gz, mgz, or npz within a folder.""" basename = os.path.basename(path_dir) - if include_single_image & \ - (('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename)): - assert os.path.isfile(path_dir), 'file %s does not exist' % path_dir + if include_single_image & ( + (".nii.gz" in basename) + | (".nii" in basename) + | (".mgz" in basename) + | (".npz" in basename) + ): + assert os.path.isfile(path_dir), "file %s does not exist" % path_dir list_images = [path_dir] else: if os.path.isdir(path_dir): - list_images = sorted(glob.glob(os.path.join(path_dir, '*nii.gz')) + - glob.glob(os.path.join(path_dir, '*nii')) + - glob.glob(os.path.join(path_dir, '*.mgz')) + - glob.glob(os.path.join(path_dir, '*.npz'))) + list_images = sorted( + glob.glob(os.path.join(path_dir, "*nii.gz")) + + glob.glob(os.path.join(path_dir, "*nii")) + + glob.glob(os.path.join(path_dir, "*.mgz")) + + glob.glob(os.path.join(path_dir, "*.npz")) + ) else: - raise Exception('Folder does not exist: %s' % path_dir) + raise Exception("Folder does not exist: %s" % path_dir) if check_if_empty: - assert len(list_images) > 0, 'no .nii, .nii.gz, .mgz or .npz image could be found in %s' % path_dir + assert len(list_images) > 0, ( + "no .nii, .nii.gz, .mgz or .npz image could be found in %s" % path_dir + ) return list_images -def list_files(path_dir, whole_path=True, expr=None, cond_type='or'): +def list_files(path_dir, whole_path=True, expr=None, cond_type="or"): """This function returns a list of files contained in a folder, with possible regexp. :param path_dir: path of a folder :param whole_path: (optional) whether to return whole path or just the filenames. @@ -430,31 +550,46 @@ def list_files(path_dir, whole_path=True, expr=None, cond_type='or'): :return: a list of files """ assert isinstance(whole_path, bool), "whole_path should be bool" - assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'" + assert cond_type in ["or", "and"], "cond_type should be either 'or', or 'and'" if whole_path: - files_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir) - if os.path.isfile(os.path.join(path_dir, f))]) + files_list = sorted( + [ + os.path.join(path_dir, f) + for f in os.listdir(path_dir) + if os.path.isfile(os.path.join(path_dir, f)) + ] + ) else: - files_list = sorted([f for f in os.listdir(path_dir) if os.path.isfile(os.path.join(path_dir, f))]) + files_list = sorted( + [ + f + for f in os.listdir(path_dir) + if os.path.isfile(os.path.join(path_dir, f)) + ] + ) if expr is not None: # assumed to be either str or list of str if isinstance(expr, str): expr = [expr] elif not isinstance(expr, (list, tuple)): - raise Exception("if specified, 'expr' should be a string or list of strings.") + raise Exception( + "if specified, 'expr' should be a string or list of strings." + ) matched_list_files = list() for match in expr: - tmp_matched_files_list = sorted([f for f in files_list if match in os.path.basename(f)]) - if cond_type == 'or': + tmp_matched_files_list = sorted( + [f for f in files_list if match in os.path.basename(f)] + ) + if cond_type == "or": files_list = [f for f in files_list if f not in tmp_matched_files_list] matched_list_files += tmp_matched_files_list - elif cond_type == 'and': + elif cond_type == "and": files_list = tmp_matched_files_list matched_list_files = tmp_matched_files_list files_list = sorted(matched_list_files) return files_list -def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'): +def list_subfolders(path_dir, whole_path=True, expr=None, cond_type="or"): """This function returns a list of subfolders contained in a folder, with possible regexp. :param path_dir: path of a folder :param whole_path: (optional) whether to return whole path or just the subfolder names. @@ -464,24 +599,41 @@ def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'): :return: a list of subfolders """ assert isinstance(whole_path, bool), "whole_path should be bool" - assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'" + assert cond_type in ["or", "and"], "cond_type should be either 'or', or 'and'" if whole_path: - subdirs_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir) - if os.path.isdir(os.path.join(path_dir, f))]) + subdirs_list = sorted( + [ + os.path.join(path_dir, f) + for f in os.listdir(path_dir) + if os.path.isdir(os.path.join(path_dir, f)) + ] + ) else: - subdirs_list = sorted([f for f in os.listdir(path_dir) if os.path.isdir(os.path.join(path_dir, f))]) + subdirs_list = sorted( + [ + f + for f in os.listdir(path_dir) + if os.path.isdir(os.path.join(path_dir, f)) + ] + ) if expr is not None: # assumed to be either str or list of str if isinstance(expr, str): expr = [expr] elif not isinstance(expr, (list, tuple)): - raise Exception("if specified, 'expr' should be a string or list of strings.") + raise Exception( + "if specified, 'expr' should be a string or list of strings." + ) matched_list_subdirs = list() for match in expr: - tmp_matched_list_subdirs = sorted([f for f in subdirs_list if match in os.path.basename(f)]) - if cond_type == 'or': - subdirs_list = [f for f in subdirs_list if f not in tmp_matched_list_subdirs] + tmp_matched_list_subdirs = sorted( + [f for f in subdirs_list if match in os.path.basename(f)] + ) + if cond_type == "or": + subdirs_list = [ + f for f in subdirs_list if f not in tmp_matched_list_subdirs + ] matched_list_subdirs += tmp_matched_list_subdirs - elif cond_type == 'and': + elif cond_type == "and": subdirs_list = tmp_matched_list_subdirs matched_list_subdirs = tmp_matched_list_subdirs subdirs_list = sorted(matched_list_subdirs) @@ -490,53 +642,58 @@ def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'): def get_image_extension(path): name = os.path.basename(path) - if name[-7:] == '.nii.gz': - return 'nii.gz' - elif name[-4:] == '.mgz': - return 'mgz' - elif name[-4:] == '.nii': - return 'nii' - elif name[-4:] == '.npz': - return 'npz' + if name[-7:] == ".nii.gz": + return "nii.gz" + elif name[-4:] == ".mgz": + return "mgz" + elif name[-4:] == ".nii": + return "nii" + elif name[-4:] == ".npz": + return "npz" def strip_extension(path): """Strip classical image extensions (.nii.gz, .nii, .mgz, .npz) from a filename.""" - return path.replace('.nii.gz', '').replace('.nii', '').replace('.mgz', '').replace('.npz', '') + return ( + path.replace(".nii.gz", "") + .replace(".nii", "") + .replace(".mgz", "") + .replace(".npz", "") + ) def strip_suffix(path): """Strip classical image suffix from a filename.""" - path = path.replace('_aseg', '') - path = path.replace('aseg', '') - path = path.replace('.aseg', '') - path = path.replace('_aseg_1', '') - path = path.replace('_aseg_2', '') - path = path.replace('aseg_1_', '') - path = path.replace('aseg_2_', '') - path = path.replace('_orig', '') - path = path.replace('orig', '') - path = path.replace('.orig', '') - path = path.replace('_norm', '') - path = path.replace('norm', '') - path = path.replace('.norm', '') - path = path.replace('_talairach', '') - path = path.replace('GSP_FS_4p5', 'GSP') - path = path.replace('.nii_crispSegmentation', '') - path = path.replace('_crispSegmentation', '') - path = path.replace('_seg', '') - path = path.replace('.seg', '') - path = path.replace('seg', '') - path = path.replace('_seg_1', '') - path = path.replace('_seg_2', '') - path = path.replace('seg_1_', '') - path = path.replace('seg_2_', '') + path = path.replace("_aseg", "") + path = path.replace("aseg", "") + path = path.replace(".aseg", "") + path = path.replace("_aseg_1", "") + path = path.replace("_aseg_2", "") + path = path.replace("aseg_1_", "") + path = path.replace("aseg_2_", "") + path = path.replace("_orig", "") + path = path.replace("orig", "") + path = path.replace(".orig", "") + path = path.replace("_norm", "") + path = path.replace("norm", "") + path = path.replace(".norm", "") + path = path.replace("_talairach", "") + path = path.replace("GSP_FS_4p5", "GSP") + path = path.replace(".nii_crispSegmentation", "") + path = path.replace("_crispSegmentation", "") + path = path.replace("_seg", "") + path = path.replace(".seg", "") + path = path.replace("seg", "") + path = path.replace("_seg_1", "") + path = path.replace("_seg_2", "") + path = path.replace("seg_1_", "") + path = path.replace("seg_2_", "") return path def mkdir(path_dir): """Recursively creates the current dir as well as its parent folders if they do not already exist.""" - if path_dir[-1] == '/': + if path_dir[-1] == "/": path_dir = path_dir[:-1] if not os.path.isdir(path_dir): list_dir_to_create = [path_dir] @@ -549,7 +706,7 @@ def mkdir(path_dir): def mkcmd(*args): """Creates terminal command with provided inputs. Example: mkcmd('mv', 'source', 'dest') will give 'mv source dest'.""" - return ' '.join([str(arg) for arg in args]) + return " ".join([str(arg) for arg in args]) # ---------------------------------------------- shape-related functions ----------------------------------------------- @@ -591,7 +748,8 @@ def get_resample_shape(patch_shape, factor, n_channels=None): def add_axis(x, axis=0): """Add axis to a numpy array. :param x: input array - :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" + :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time. + """ axis = reformat_to_list(axis) for ax in axis: x = np.expand_dims(x, axis=ax) @@ -606,7 +764,9 @@ def get_padding_margin(cropping, loss_cropping): n_dims = max(len(cropping), len(loss_cropping)) cropping = reformat_to_list(cropping, length=n_dims) loss_cropping = reformat_to_list(loss_cropping, length=n_dims) - padding_margin = [int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims)] + padding_margin = [ + int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims) + ] if len(padding_margin) == 1: padding_margin = padding_margin[0] else: @@ -617,7 +777,9 @@ def get_padding_margin(cropping, loss_cropping): # -------------------------------------------- build affine matrices/tensors ------------------------------------------- -def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, shearing=None, translation=None): +def create_affine_transformation_matrix( + n_dims, scaling=None, rotation=None, shearing=None, translation=None +): """Create a 4x4 affine transformation matrix from specified values :param n_dims: integer, can either be 2 or 3. :param scaling: list of 3 scaling values @@ -635,14 +797,16 @@ def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, she T_scaling[np.arange(n_dims + 1), np.arange(n_dims + 1)] = np.append(scaling, 1) if shearing is not None: - shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype='bool') - shearing_index[np.eye(n_dims + 1, dtype='bool')] = False + shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype="bool") + shearing_index[np.eye(n_dims + 1, dtype="bool")] = False shearing_index[-1, :] = np.zeros((n_dims + 1)) shearing_index[:, -1] = np.zeros((n_dims + 1)) T_shearing[shearing_index] = shearing if translation is not None: - T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype='int')] = translation + T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype="int")] = ( + translation + ) if n_dims == 2: if rotation is None: @@ -650,8 +814,12 @@ def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, she else: rotation = np.asarray(rotation) * (math.pi / 180) T_rot = np.eye(n_dims + 1) - T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[0]), np.sin(rotation[0]), - np.sin(rotation[0]) * -1, np.cos(rotation[0])] + T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [ + np.cos(rotation[0]), + np.sin(rotation[0]), + np.sin(rotation[0]) * -1, + np.cos(rotation[0]), + ] return T_translation @ T_rot @ T_shearing @ T_scaling else: @@ -661,92 +829,138 @@ def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, she else: rotation = np.asarray(rotation) * (math.pi / 180) T_rot1 = np.eye(n_dims + 1) - T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [np.cos(rotation[0]), np.sin(rotation[0]), - np.sin(rotation[0]) * -1, np.cos(rotation[0])] + T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [ + np.cos(rotation[0]), + np.sin(rotation[0]), + np.sin(rotation[0]) * -1, + np.cos(rotation[0]), + ] T_rot2 = np.eye(n_dims + 1) - T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [np.cos(rotation[1]), np.sin(rotation[1]) * -1, - np.sin(rotation[1]), np.cos(rotation[1])] + T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [ + np.cos(rotation[1]), + np.sin(rotation[1]) * -1, + np.sin(rotation[1]), + np.cos(rotation[1]), + ] T_rot3 = np.eye(n_dims + 1) - T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[2]), np.sin(rotation[2]), - np.sin(rotation[2]) * -1, np.cos(rotation[2])] + T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [ + np.cos(rotation[2]), + np.sin(rotation[2]), + np.sin(rotation[2]) * -1, + np.cos(rotation[2]), + ] return T_translation @ T_rot3 @ T_rot2 @ T_rot1 @ T_shearing @ T_scaling -def sample_affine_transform(batchsize, - n_dims, - rotation_bounds=False, - scaling_bounds=False, - shearing_bounds=False, - translation_bounds=False, - enable_90_rotations=False): +def sample_affine_transform( + batchsize, + n_dims, + rotation_bounds=False, + scaling_bounds=False, + shearing_bounds=False, + translation_bounds=False, + enable_90_rotations=False, +): """build batchsize x 4 x 4 tensor representing an affine transformation in homogeneous coordinates. If return_inv is True, also returns the inverse of the created affine matrix.""" if (rotation_bounds is not False) | (enable_90_rotations is not False): if n_dims == 2: if rotation_bounds is not False: - rotation = draw_value_from_distribution(rotation_bounds, - size=1, - default_range=15.0, - return_as_tensor=True, - batchsize=batchsize) + rotation = draw_value_from_distribution( + rotation_bounds, + size=1, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize, + ) else: - rotation = tf.zeros(tf.concat([batchsize, tf.ones(1, dtype='int32')], axis=0)) + rotation = tf.zeros( + tf.concat([batchsize, tf.ones(1, dtype="int32")], axis=0) + ) else: # n_dims = 3 if rotation_bounds is not False: - rotation = draw_value_from_distribution(rotation_bounds, - size=n_dims, - default_range=15.0, - return_as_tensor=True, - batchsize=batchsize) + rotation = draw_value_from_distribution( + rotation_bounds, + size=n_dims, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize, + ) else: - rotation = tf.zeros(tf.concat([batchsize, 3 * tf.ones(1, dtype='int32')], axis=0)) + rotation = tf.zeros( + tf.concat([batchsize, 3 * tf.ones(1, dtype="int32")], axis=0) + ) if enable_90_rotations: - rotation = tf.cast(tf.random.uniform(tf.shape(rotation), maxval=4, dtype='int32') * 90, 'float32') \ - + rotation + rotation = ( + tf.cast( + tf.random.uniform(tf.shape(rotation), maxval=4, dtype="int32") * 90, + "float32", + ) + + rotation + ) T_rot = create_rotation_transform(rotation, n_dims) else: - T_rot = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), - tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + T_rot = tf.tile( + tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0), + ) if shearing_bounds is not False: - shearing = draw_value_from_distribution(shearing_bounds, - size=n_dims ** 2 - n_dims, - default_range=.01, - return_as_tensor=True, - batchsize=batchsize) + shearing = draw_value_from_distribution( + shearing_bounds, + size=n_dims**2 - n_dims, + default_range=0.01, + return_as_tensor=True, + batchsize=batchsize, + ) T_shearing = create_shearing_transform(shearing, n_dims) else: - T_shearing = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), - tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + T_shearing = tf.tile( + tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0), + ) if scaling_bounds is not False: - scaling = draw_value_from_distribution(scaling_bounds, - size=n_dims, - centre=1, - default_range=.15, - return_as_tensor=True, - batchsize=batchsize) + scaling = draw_value_from_distribution( + scaling_bounds, + size=n_dims, + centre=1, + default_range=0.15, + return_as_tensor=True, + batchsize=batchsize, + ) T_scaling = tf.linalg.diag(scaling) else: - T_scaling = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), - tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + T_scaling = tf.tile( + tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0), + ) T = tf.matmul(T_scaling, tf.matmul(T_shearing, T_rot)) if translation_bounds is not False: - translation = draw_value_from_distribution(translation_bounds, - size=n_dims, - default_range=5, - return_as_tensor=True, - batchsize=batchsize) + translation = draw_value_from_distribution( + translation_bounds, + size=n_dims, + default_range=5, + return_as_tensor=True, + batchsize=batchsize, + ) T = tf.concat([T, tf.expand_dims(translation, axis=-1)], axis=-1) else: - T = tf.concat([T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype='int32')], 0))], axis=-1) + T = tf.concat( + [T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype="int32")], 0))], + axis=-1, + ) # build rigid transform - T_last_row = tf.expand_dims(tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0) - T_last_row = tf.tile(T_last_row, tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + T_last_row = tf.expand_dims( + tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0 + ) + T_last_row = tf.tile( + T_last_row, tf.concat([batchsize, tf.ones(2, dtype="int32")], axis=0) + ) T = tf.concat([T, T_last_row], axis=1) return T @@ -758,38 +972,93 @@ def create_rotation_transform(rotation, n_dims): if n_dims == 3: shape = tf.shape(tf.expand_dims(rotation[..., 0], -1)) - Rx_row0 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([1., 0., 0.]), 0), shape), axis=1) - Rx_row1 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.cos(rotation[..., 0]), -1), - tf.expand_dims(-tf.sin(rotation[..., 0]), -1)], axis=-1) - Rx_row2 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.sin(rotation[..., 0]), -1), - tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1) + Rx_row0 = tf.expand_dims( + tf.tile(tf.expand_dims(tf.convert_to_tensor([1.0, 0.0, 0.0]), 0), shape), + axis=1, + ) + Rx_row1 = tf.stack( + [ + tf.zeros(shape), + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(-tf.sin(rotation[..., 0]), -1), + ], + axis=-1, + ) + Rx_row2 = tf.stack( + [ + tf.zeros(shape), + tf.expand_dims(tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + ], + axis=-1, + ) Rx = tf.concat([Rx_row0, Rx_row1, Rx_row2], axis=1) - Ry_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 1]), -1), tf.zeros(shape), - tf.expand_dims(tf.sin(rotation[..., 1]), -1)], axis=-1) - Ry_row1 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 1., 0.]), 0), shape), axis=1) - Ry_row2 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 1]), -1), tf.zeros(shape), - tf.expand_dims(tf.cos(rotation[..., 1]), -1)], axis=-1) + Ry_row0 = tf.stack( + [ + tf.expand_dims(tf.cos(rotation[..., 1]), -1), + tf.zeros(shape), + tf.expand_dims(tf.sin(rotation[..., 1]), -1), + ], + axis=-1, + ) + Ry_row1 = tf.expand_dims( + tf.tile(tf.expand_dims(tf.convert_to_tensor([0.0, 1.0, 0.0]), 0), shape), + axis=1, + ) + Ry_row2 = tf.stack( + [ + tf.expand_dims(-tf.sin(rotation[..., 1]), -1), + tf.zeros(shape), + tf.expand_dims(tf.cos(rotation[..., 1]), -1), + ], + axis=-1, + ) Ry = tf.concat([Ry_row0, Ry_row1, Ry_row2], axis=1) - Rz_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 2]), -1), - tf.expand_dims(-tf.sin(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1) - Rz_row1 = tf.stack([tf.expand_dims(tf.sin(rotation[..., 2]), -1), - tf.expand_dims(tf.cos(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1) - Rz_row2 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 0., 1.]), 0), shape), axis=1) + Rz_row0 = tf.stack( + [ + tf.expand_dims(tf.cos(rotation[..., 2]), -1), + tf.expand_dims(-tf.sin(rotation[..., 2]), -1), + tf.zeros(shape), + ], + axis=-1, + ) + Rz_row1 = tf.stack( + [ + tf.expand_dims(tf.sin(rotation[..., 2]), -1), + tf.expand_dims(tf.cos(rotation[..., 2]), -1), + tf.zeros(shape), + ], + axis=-1, + ) + Rz_row2 = tf.expand_dims( + tf.tile(tf.expand_dims(tf.convert_to_tensor([0.0, 0.0, 1.0]), 0), shape), + axis=1, + ) Rz = tf.concat([Rz_row0, Rz_row1, Rz_row2], axis=1) T_rot = tf.matmul(tf.matmul(Rx, Ry), Rz) elif n_dims == 2: - R_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 0]), -1), - tf.expand_dims(tf.sin(rotation[..., 0]), -1)], axis=-1) - R_row1 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 0]), -1), - tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1) + R_row0 = tf.stack( + [ + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(tf.sin(rotation[..., 0]), -1), + ], + axis=-1, + ) + R_row1 = tf.stack( + [ + tf.expand_dims(-tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1), + ], + axis=-1, + ) T_rot = tf.concat([R_row0, R_row1], axis=1) else: - raise Exception('only supports 2 or 3D.') + raise Exception("only supports 2 or 3D.") return T_rot @@ -798,20 +1067,42 @@ def create_shearing_transform(shearing, n_dims): """build shearing transform from 2d/3d shearing coefficients""" shape = tf.shape(tf.expand_dims(shearing[..., 0], -1)) if n_dims == 3: - shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1), - tf.expand_dims(shearing[..., 1], -1)], axis=-1) - shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 2], -1), tf.ones(shape), - tf.expand_dims(shearing[..., 3], -1)], axis=-1) - shearing_row2 = tf.stack([tf.expand_dims(shearing[..., 4], -1), tf.expand_dims(shearing[..., 5], -1), - tf.ones(shape)], axis=-1) + shearing_row0 = tf.stack( + [ + tf.ones(shape), + tf.expand_dims(shearing[..., 0], -1), + tf.expand_dims(shearing[..., 1], -1), + ], + axis=-1, + ) + shearing_row1 = tf.stack( + [ + tf.expand_dims(shearing[..., 2], -1), + tf.ones(shape), + tf.expand_dims(shearing[..., 3], -1), + ], + axis=-1, + ) + shearing_row2 = tf.stack( + [ + tf.expand_dims(shearing[..., 4], -1), + tf.expand_dims(shearing[..., 5], -1), + tf.ones(shape), + ], + axis=-1, + ) T_shearing = tf.concat([shearing_row0, shearing_row1, shearing_row2], axis=1) elif n_dims == 2: - shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1) - shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1) + shearing_row0 = tf.stack( + [tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1 + ) + shearing_row1 = tf.stack( + [tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1 + ) T_shearing = tf.concat([shearing_row0, shearing_row1], axis=1) else: - raise Exception('only supports 2 or 3D.') + raise Exception("only supports 2 or 3D.") return T_shearing @@ -819,16 +1110,18 @@ def create_shearing_transform(shearing, n_dims): def infer(x): - """ Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string """ + """Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string""" try: x = float(x) except ValueError: - if x == 'False': + if x == "False": x = False - elif x == 'True': + elif x == "True": x = True elif not isinstance(x, str): - raise TypeError('input should be an int/float/boolean/str, had {}'.format(type(x))) + raise TypeError( + "input should be an int/float/boolean/str, had {}".format(type(x)) + ) return x @@ -840,7 +1133,7 @@ class LoopInfo: processing i/total remaining time: hh:mm:ss """ - def __init__(self, n_iterations, spacing=10, text='processing', print_time=False): + def __init__(self, n_iterations, spacing=10, text="processing", print_time=False): """ :param n_iterations: total number of iterations of the for loop. :param spacing: frequency at which the update info will be printed on screen. @@ -872,23 +1165,32 @@ def update(self, idx): # print text if idx == 0: - print(self.text + ' 1/{}'.format(self.n_iterations)) + print(self.text + " 1/{}".format(self.n_iterations)) elif idx % self.spacing == self.spacing - 1: - iteration = str(idx + 1) + '/' + str(self.n_iterations) + iteration = str(idx + 1) + "/" + str(self.n_iterations) if self.print_time: # estimate remaining time max_duration = np.max(self.iteration_durations) - average_duration = np.mean(self.iteration_durations[self.iteration_durations > .01 * max_duration]) + average_duration = np.mean( + self.iteration_durations[ + self.iteration_durations > 0.01 * max_duration + ] + ) remaining_time = int(average_duration * (self.n_iterations - idx)) # print total remaining time only if it is greater than 1s or if it was previously printed if (remaining_time > 1) | self.print_previous_time: eta = str(timedelta(seconds=remaining_time)) - print(self.text + ' {:<{x}} remaining time: {}'.format(iteration, eta, x=self.align)) + print( + self.text + + " {:<{x}} remaining time: {}".format( + iteration, eta, x=self.align + ) + ) self.print_previous_time = True else: - print(self.text + ' {}'.format(iteration)) + print(self.text + " {}".format(iteration)) else: - print(self.text + ' {}'.format(iteration)) + print(self.text + " {}".format(iteration)) def get_mapping_lut(source, dest=None): @@ -896,18 +1198,20 @@ def get_mapping_lut(source, dest=None): If the second list is not given, we assume it is equal to [0, ..., N-1].""" # initialise - source = np.array(reformat_to_list(source), dtype='int32') + source = np.array(reformat_to_list(source), dtype="int32") n_labels = source.shape[0] # build new label list if necessary if dest is None: - dest = np.arange(n_labels, dtype='int32') + dest = np.arange(n_labels, dtype="int32") else: - assert len(source) == len(dest), 'label_list and new_label_list should have the same length' - dest = np.array(reformat_to_list(dest, dtype='int')) + assert len(source) == len( + dest + ), "label_list and new_label_list should have the same length" + dest = np.array(reformat_to_list(dest, dtype="int")) # build look-up table - lut = np.zeros(np.max(source) + 1, dtype='int32') + lut = np.zeros(np.max(source) + 1, dtype="int32") for source, dest in zip(source, dest): lut[source] = dest @@ -925,7 +1229,7 @@ def build_training_generator(gen, batchsize): yield inputs, target -def find_closest_number_divisible_by_m(n, m, answer_type='lower'): +def find_closest_number_divisible_by_m(n, m, answer_type="lower"): """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns values lower than n), or 'higher' (only returns values higher than m).""" if n % m == 0: @@ -934,14 +1238,16 @@ def find_closest_number_divisible_by_m(n, m, answer_type='lower'): q = int(n / m) lower = q * m higher = (q + 1) * m - if answer_type == 'lower': + if answer_type == "lower": return lower - elif answer_type == 'higher': + elif answer_type == "higher": return higher - elif answer_type == 'closer': + elif answer_type == "closer": return lower if (n - lower) < (higher - n) else higher else: - raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type) + raise Exception( + "answer_type should be lower, higher, or closer, had : %s" % answer_type + ) def build_binary_structure(connectivity, n_dims, shape=None): @@ -958,14 +1264,16 @@ def build_binary_structure(connectivity, n_dims, shape=None): return struct -def draw_value_from_distribution(hyperparameter, - size=1, - distribution='uniform', - centre=0., - default_range=10.0, - positive_only=False, - return_as_tensor=False, - batchsize=None): +def draw_value_from_distribution( + hyperparameter, + size=1, + distribution="uniform", + centre=0.0, + default_range=10.0, + positive_only=False, + return_as_tensor=False, + batchsize=None, +): """Sample values from a uniform, or normal distribution of given hyperparameters. These hyperparameters are to the number of 2 in both uniform and normal cases. :param hyperparameter: values of the hyperparameters. Can either be: @@ -1001,47 +1309,73 @@ def draw_value_from_distribution(hyperparameter, hyperparameter = load_array_if_path(hyperparameter, load_as_numpy=True) if not isinstance(hyperparameter, np.ndarray): if hyperparameter is None: - hyperparameter = np.array([[centre - default_range] * size, [centre + default_range] * size]) + hyperparameter = np.array( + [[centre - default_range] * size, [centre + default_range] * size] + ) elif isinstance(hyperparameter, (int, float)): - hyperparameter = np.array([[centre - hyperparameter] * size, [centre + hyperparameter] * size]) + hyperparameter = np.array( + [[centre - hyperparameter] * size, [centre + hyperparameter] * size] + ) elif isinstance(hyperparameter, (list, tuple)): - assert len(hyperparameter) == 2, 'if list, parameter_range should be of length 2.' + assert ( + len(hyperparameter) == 2 + ), "if list, parameter_range should be of length 2." hyperparameter = np.transpose(np.tile(np.array(hyperparameter), (size, 1))) else: - raise ValueError('parameter_range should either be None, a number, a sequence, or a numpy array.') + raise ValueError( + "parameter_range should either be None, a number, a sequence, or a numpy array." + ) elif isinstance(hyperparameter, np.ndarray): - assert hyperparameter.shape[0] % 2 == 0, 'number of rows of parameter_range should be divisible by 2' + assert ( + hyperparameter.shape[0] % 2 == 0 + ), "number of rows of parameter_range should be divisible by 2" n_modalities = int(hyperparameter.shape[0] / 2) modality_idx = 2 * np.random.randint(n_modalities) - hyperparameter = hyperparameter[modality_idx: modality_idx + 2, :] + hyperparameter = hyperparameter[modality_idx : modality_idx + 2, :] # draw values as tensor if return_as_tensor: - shape = KL.Lambda(lambda x: tf.convert_to_tensor(hyperparameter.shape[1], 'int32'))([]) + shape = KL.Lambda( + lambda x: tf.convert_to_tensor(hyperparameter.shape[1], "int32") + )([]) if batchsize is not None: - shape = KL.Lambda(lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0))([batchsize, shape]) - if distribution == 'uniform': - parameter_value = KL.Lambda(lambda x: tf.random.uniform(shape=x, - minval=hyperparameter[0, :], - maxval=hyperparameter[1, :]))(shape) - elif distribution == 'normal': - parameter_value = KL.Lambda(lambda x: tf.random.normal(shape=x, - mean=hyperparameter[0, :], - stddev=hyperparameter[1, :]))(shape) + shape = KL.Lambda( + lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0) + )([batchsize, shape]) + if distribution == "uniform": + parameter_value = KL.Lambda( + lambda x: tf.random.uniform( + shape=x, minval=hyperparameter[0, :], maxval=hyperparameter[1, :] + ) + )(shape) + elif distribution == "normal": + parameter_value = KL.Lambda( + lambda x: tf.random.normal( + shape=x, mean=hyperparameter[0, :], stddev=hyperparameter[1, :] + ) + )(shape) else: - raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.") + raise ValueError( + "Distribution not supported, should be 'uniform' or 'normal'." + ) if positive_only: parameter_value = KL.Lambda(lambda x: K.clip(x, 0, None))(parameter_value) # draw values as numpy array else: - if distribution == 'uniform': - parameter_value = np.random.uniform(low=hyperparameter[0, :], high=hyperparameter[1, :]) - elif distribution == 'normal': - parameter_value = np.random.normal(loc=hyperparameter[0, :], scale=hyperparameter[1, :]) + if distribution == "uniform": + parameter_value = np.random.uniform( + low=hyperparameter[0, :], high=hyperparameter[1, :] + ) + elif distribution == "normal": + parameter_value = np.random.normal( + loc=hyperparameter[0, :], scale=hyperparameter[1, :] + ) else: - raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.") + raise ValueError( + "Distribution not supported, should be 'uniform' or 'normal'." + ) if positive_only: parameter_value[parameter_value < 0] = 0 @@ -1053,5 +1387,5 @@ def build_exp(x, first, last, fix_point): # first = f(0), last = f(+inf), fix_point = [x0, f(x0))] a = last b = first - last - c = - (1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last)) + c = -(1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last)) return a + b * np.exp(-c * x) diff --git a/nobrainer/ext/neuron/__init__.py b/nobrainer/ext/neuron/__init__.py index 2f28f4d4..6a418b95 100644 --- a/nobrainer/ext/neuron/__init__.py +++ b/nobrainer/ext/neuron/__init__.py @@ -1,3 +1 @@ -from . import layers -from . import models -from . import utils +from . import layers, models, utils diff --git a/nobrainer/ext/neuron/layers.py b/nobrainer/ext/neuron/layers.py index 61b46a78..9f4e1821 100644 --- a/nobrainer/ext/neuron/layers.py +++ b/nobrainer/ext/neuron/layers.py @@ -1,9 +1,9 @@ """ tensorflow/keras utilities for the neuron project -If you use this code, please cite +If you use this code, please cite Dalca AV, Guttag J, Sabuncu MR -Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, CVPR 2018 or for the transformation/integration functions: @@ -16,24 +16,32 @@ License: GPLv3 """ -# third party -import tensorflow as tf +from copy import deepcopy + from keras import backend as K from keras.layers import Layer -from copy import deepcopy + +# third party +import tensorflow as tf # local -from nobrainer.ext.neuron.utils import transform, resize, integrate_vec, affine_to_shift, combine_non_linear_and_aff_to_shift +from nobrainer.ext.neuron.utils import ( + affine_to_shift, + combine_non_linear_and_aff_to_shift, + integrate_vec, + resize, + transform, +) class SpatialTransformer(Layer): """ N-D Spatial Transformer Tensorflow / Keras Layer - The Layer can handle both affine and dense transforms. + The Layer can handle both affine and dense transforms. Both transforms are meant to give a 'shift' from the current position. Therefore, a dense transform gives displacements (not absolute locations) at each voxel, - and an affine transform gives the *difference* of the affine matrix from + and an affine transform gives the *difference* of the affine matrix from the identity matrix. If you find this function useful, please cite: @@ -41,25 +49,23 @@ class SpatialTransformer(Layer): Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu MICCAI 2018. - Originally, this code was based on voxelmorph code, which - was in turn transformed to be dense with the help of (affine) STN code + Originally, this code was based on voxelmorph code, which + was in turn transformed to be dense with the help of (affine) STN code via https://github.com/kevinzakka/spatial-transformer-network - Since then, we've re-written the code to be generalized to any + Since then, we've re-written the code to be generalized to any dimensions, and along the way wrote grid and interpolation functions """ - def __init__(self, - interp_method='linear', - indexing='ij', - single_transform=False, - **kwargs): + def __init__( + self, interp_method="linear", indexing="ij", single_transform=False, **kwargs + ): """ - Parameters: + Parameters: interp_method: 'linear' or 'nearest' single_transform: whether a single transform supplied for the whole batch indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian) - 'xy' indexing will have the first two entries of the flow + 'xy' indexing will have the first two entries of the flow (along last axis) flipped compared to 'ij' indexing """ self.interp_method = interp_method @@ -68,7 +74,10 @@ def __init__(self, self.single_transform = single_transform self.is_affine = list() - assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + assert indexing in [ + "ij", + "xy", + ], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" self.indexing = indexing super(self.__class__, self).__init__(**kwargs) @@ -93,31 +102,41 @@ def build(self, input_shape): """ if len(input_shape) > 3: - raise Exception('Spatial Transformer must be called on a list of min length 2 and max length 3.' - 'First argument is the image followed by the affine and non linear transforms.') + raise Exception( + "Spatial Transformer must be called on a list of min length 2 and max length 3." + "First argument is the image followed by the affine and non linear transforms." + ) # set up number of dimensions self.ndims = len(input_shape[0]) - 2 self.inshape = input_shape trf_shape = [trans_shape[1:] for trans_shape in input_shape[1:]] - for (i, shape) in enumerate(trf_shape): + for i, shape in enumerate(trf_shape): # the transform is an affine iff: # it's a 1D Tensor [dense transforms need to be at least ndims + 1] # it's a 2D Tensor and shape == [N+1, N+1]. - self.is_affine.append(len(shape) == 1 or - (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape]))) + self.is_affine.append( + len(shape) == 1 + or (len(shape) == 2 and all([f == (self.ndims + 1) for f in shape])) + ) # check sizes if self.is_affine[i] and len(shape) == 1: ex = self.ndims * (self.ndims + 1) if shape[0] != ex: - raise Exception('Expected flattened affine of len %d but got %d' % (ex, shape[0])) + raise Exception( + "Expected flattened affine of len %d but got %d" + % (ex, shape[0]) + ) if not self.is_affine[i]: if shape[-1] != self.ndims: - raise Exception('Offset flow field size expected: %d, found: %d' % (self.ndims, shape[-1])) + raise Exception( + "Offset flow field size expected: %d, found: %d" + % (self.ndims, shape[-1]) + ) # confirm built self.built = True @@ -129,17 +148,21 @@ def call(self, inputs, **kwargs): """ # check shapes - assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len(inputs) + assert 1 < len(inputs) < 4, "inputs has to be len 2 or 3, found: %d" % len( + inputs + ) vol = inputs[0] trf = inputs[1:] # necessary for multi_gpu models... vol = K.reshape(vol, [-1, *self.inshape[0][1:]]) for i in range(len(trf)): - trf[i] = K.reshape(trf[i], [-1, *self.inshape[i+1][1:]]) + trf[i] = K.reshape(trf[i], [-1, *self.inshape[i + 1][1:]]) # reorder transforms, non-linear first and affine second - ind_nonlinear_linear = [i[0] for i in sorted(enumerate(self.is_affine), key=lambda x:x[1])] + ind_nonlinear_linear = [ + i[0] for i in sorted(enumerate(self.is_affine), key=lambda x: x[1]) + ] self.is_affine = [self.is_affine[i] for i in ind_nonlinear_linear] self.inshape = [self.inshape[i] for i in ind_nonlinear_linear] trf = [trf[i] for i in ind_nonlinear_linear] @@ -148,13 +171,21 @@ def call(self, inputs, **kwargs): if len(trf) == 1: trf = trf[0] if self.is_affine[0]: - trf = tf.map_fn(lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32) + trf = tf.map_fn( + lambda x: self._single_aff_to_shift(x, vol.shape[1:-1]), + trf, + dtype=tf.float32, + ) # combine non-linear and affine to obtain a single deformation field elif len(trf) == 2: - trf = tf.map_fn(lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), trf, dtype=tf.float32) + trf = tf.map_fn( + lambda x: self._non_linear_and_aff_to_shift(x, vol.shape[1:-1]), + trf, + dtype=tf.float32, + ) # prepare location shift - if self.indexing == 'xy': # shift the first two dimensions + if self.indexing == "xy": # shift the first two dimensions trf_split = tf.split(trf, trf.shape[-1], axis=-1) trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]] trf = tf.concat(trf_lst, -1) @@ -183,8 +214,8 @@ class VecInt(Layer): """ Vector Integration Layer - Enables vector integration via several methods - (ode or quadrature for time-dependent vector fields, + Enables vector integration via several methods + (ode or quadrature for time-dependent vector fields, scaling and squaring for stationary fields) If you find this function useful, please cite: @@ -193,10 +224,17 @@ class VecInt(Layer): MICCAI 2018. """ - def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1, - ode_args=None, - odeint_fn=None, **kwargs): - """ + def __init__( + self, + indexing="ij", + method="ss", + int_steps=7, + out_time_pt=1, + ode_args=None, + odeint_fn=None, + **kwargs + ): + """ Parameters: method can be any of the methods in neuron.utils.integrate_vec indexing can be 'xy' (switches first two dimensions) or 'ij' @@ -204,7 +242,10 @@ def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1, out_time_pt is time point at which to output if using odeint integration """ - assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" + assert indexing in [ + "ij", + "xy", + ], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)" self.indexing = indexing self.method = method self.int_steps = int_steps @@ -213,7 +254,7 @@ def __init__(self, indexing='ij', method='ss', int_steps=7, out_time_pt=1, self.odeint_fn = odeint_fn # if none then will use a tensorflow function self.ode_args = ode_args if ode_args is None: - self.ode_args = {'rtol': 1e-6, 'atol': 1e-12} + self.ode_args = {"rtol": 1e-6, "atol": 1e-12} super(self.__class__, self).__init__(**kwargs) def get_config(self): @@ -236,7 +277,10 @@ def build(self, input_shape): self.inshape = trf_shape if trf_shape[-1] != len(trf_shape) - 2: - raise Exception('transform ndims %d does not match expected ndims %d' % (trf_shape[-1], len(trf_shape) - 2)) + raise Exception( + "transform ndims %d does not match expected ndims %d" + % (trf_shape[-1], len(trf_shape) - 2) + ) def call(self, inputs, **kwargs): if not isinstance(inputs, (list, tuple)): @@ -247,13 +291,19 @@ def call(self, inputs, **kwargs): loc_shift = K.reshape(loc_shift, [-1, *self.inshape[1:]]) # prepare location shift - if self.indexing == 'xy': # shift the first two dimensions + if self.indexing == "xy": # shift the first two dimensions loc_shift_split = tf.split(loc_shift, loc_shift.shape[-1], axis=-1) - loc_shift_lst = [loc_shift_split[1], loc_shift_split[0], *loc_shift_split[2:]] + loc_shift_lst = [ + loc_shift_split[1], + loc_shift_split[0], + *loc_shift_split[2:], + ] loc_shift = tf.concat(loc_shift_lst, -1) if len(inputs) > 1: - assert self.out_time_pt is None, 'out_time_pt should be None if providing batch_based out_time_pt' + assert ( + self.out_time_pt is None + ), "out_time_pt should be None if providing batch_based out_time_pt" # map transform across batch out = tf.map_fn(self._single_int, [loc_shift] + inputs[1:], dtype=tf.float32) @@ -265,11 +315,14 @@ def _single_int(self, inputs): out_time_pt = self.out_time_pt if len(inputs) == 2: out_time_pt = inputs[1] - return integrate_vec(vel, method=self.method, - nb_steps=self.int_steps, - ode_args=self.ode_args, - out_time_pt=out_time_pt, - odeint_fn=self.odeint_fn) + return integrate_vec( + vel, + method=self.method, + nb_steps=self.int_steps, + ode_args=self.ode_args, + out_time_pt=out_time_pt, + odeint_fn=self.odeint_fn, + ) class Resize(Layer): @@ -281,19 +334,15 @@ class Resize(Layer): Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation,Dalca AV, Guttag J, Sabuncu MR CVPR 2018 - Since then, we've re-written the code to be generalized to any + Since then, we've re-written the code to be generalized to any dimensions, and along the way wrote grid and interpolation functions """ - def __init__(self, - zoom_factor=None, - size=None, - interp_method='linear', - **kwargs): + def __init__(self, zoom_factor=None, size=None, interp_method="linear", **kwargs): """ - Parameters: + Parameters: interp_method: 'linear' or 'nearest' - 'xy' indexing will have the first two entries of the flow + 'xy' indexing will have the first two entries of the flow (along last axis) flipped compared to 'ij' indexing """ self.zoom_factor = zoom_factor @@ -320,7 +369,7 @@ def build(self, input_shape): """ if isinstance(input_shape[0], (list, tuple)) and len(input_shape) > 1: - raise Exception('Resize must be called on a list of length 1.') + raise Exception("Resize must be called on a list of length 1.") if isinstance(input_shape[0], (list, tuple)): input_shape = input_shape[0] @@ -336,10 +385,15 @@ def build(self, input_shape): self.zoom_factor0 = [0] * self.ndims elif isinstance(self.zoom_factor, (list, tuple)): self.zoom_factor0 = deepcopy(self.zoom_factor) - assert len(self.zoom_factor0) == self.ndims, \ - 'zoom factor length {} does not match number of dimensions {}'.format(len(self.zoom_factor), self.ndims) + assert ( + len(self.zoom_factor0) == self.ndims + ), "zoom factor length {} does not match number of dimensions {}".format( + len(self.zoom_factor), self.ndims + ) else: - raise Exception('zoom_factor should be an int or a list/tuple of int (or None if size is not set to None)') + raise Exception( + "zoom_factor should be an int or a list/tuple of int (or None if size is not set to None)" + ) # check size if isinstance(self.size, int): @@ -348,10 +402,15 @@ def build(self, input_shape): self.size0 = [0] * self.ndims elif isinstance(self.size, (list, tuple)): self.size0 = deepcopy(self.size) - assert len(self.size0) == self.ndims, \ - 'size length {} does not match number of dimensions {}'.format(len(self.size0), self.ndims) + assert ( + len(self.size0) == self.ndims + ), "size length {} does not match number of dimensions {}".format( + len(self.size0), self.ndims + ) else: - raise Exception('size should be an int or a list/tuple of int (or None if zoom_factor is not set to None)') + raise Exception( + "size should be an int or a list/tuple of int (or None if zoom_factor is not set to None)" + ) # confirm built self.built = True @@ -376,9 +435,14 @@ def call(self, inputs, **kwargs): # set value of missing size or zoom_factor if not any(self.zoom_factor0): - self.zoom_factor0 = [self.size0[i] / self.inshape[i+1] for i in range(self.ndims)] + self.zoom_factor0 = [ + self.size0[i] / self.inshape[i + 1] for i in range(self.ndims) + ] else: - self.size0 = [int(self.inshape[f+1] * self.zoom_factor0[f]) for f in range(self.ndims)] + self.size0 = [ + int(self.inshape[f + 1] * self.zoom_factor0[f]) + for f in range(self.ndims) + ] # map transform across batch return tf.map_fn(self._single_resize, vol, dtype=vol.dtype) @@ -386,12 +450,16 @@ def call(self, inputs, **kwargs): def compute_output_shape(self, input_shape): output_shape = [input_shape[0]] - output_shape += [int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims)] + output_shape += [ + int(input_shape[1:-1][f] * self.zoom_factor0[f]) for f in range(self.ndims) + ] output_shape += [input_shape[-1]] return tuple(output_shape) def _single_resize(self, inputs): - return resize(inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method) + return resize( + inputs, self.zoom_factor0, self.size0, interp_method=self.interp_method + ) # Zoom naming of resize, to match scipy's naming @@ -402,13 +470,14 @@ def _single_resize(self, inputs): # "Local" layers -- layers with parameters at each voxel ######################################################### + class LocalBias(Layer): - """ + """ Local bias layer: each pixel/voxel has its own bias operation (one parameter) out[v] = in[v] + b """ - def __init__(self, my_initializer='RandomNormal', biasmult=1.0, **kwargs): + def __init__(self, my_initializer="RandomNormal", biasmult=1.0, **kwargs): self.initializer = my_initializer self.biasmult = biasmult self.kernel = None @@ -422,10 +491,12 @@ def get_config(self): def build(self, input_shape): # Create a trainable weight variable for this layer. - self.kernel = self.add_weight(name='kernel', - shape=input_shape[1:], - initializer=self.initializer, - trainable=True) + self.kernel = self.add_weight( + name="kernel", + shape=input_shape[1:], + initializer=self.initializer, + trainable=True, + ) super(LocalBias, self).build(input_shape) # Be sure to call this somewhere! def call(self, x, **kwargs): diff --git a/nobrainer/ext/neuron/models.py b/nobrainer/ext/neuron/models.py index 9b5c87ed..a13a9167 100644 --- a/nobrainer/ext/neuron/models.py +++ b/nobrainer/ext/neuron/models.py @@ -1,9 +1,9 @@ """ tensorflow/keras utilities for the neuron project -If you use this code, please cite +If you use this code, please cite Dalca AV, Guttag J, Sabuncu MR -Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, CVPR 2018 Contact: adalca [at] csail [dot] mit [dot] edu @@ -12,39 +12,42 @@ import sys -from nobrainer.ext.neuron import layers +import keras +import keras.backend as K +import keras.layers as KL +from keras.models import Model # third party import numpy as np import tensorflow as tf -import keras -import keras.layers as KL -from keras.models import Model -import keras.backend as K + +from nobrainer.ext.neuron import layers -def unet(nb_features, - input_shape, - nb_levels, - conv_size, - nb_labels, - name='unet', - prefix=None, - feat_mult=1, - pool_size=2, - use_logp=True, - padding='same', - dilation_rate_mult=1, - activation='elu', - skip_n_concatenations=0, - use_residuals=False, - final_pred_activation='softmax', - nb_conv_per_level=1, - add_prior_layer=False, - layer_nb_feats=None, - conv_dropout=0, - batch_norm=None, - input_model=None): +def unet( + nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name="unet", + prefix=None, + feat_mult=1, + pool_size=2, + use_logp=True, + padding="same", + dilation_rate_mult=1, + activation="elu", + skip_n_concatenations=0, + use_residuals=False, + final_pred_activation="softmax", + nb_conv_per_level=1, + add_prior_layer=False, + layer_nb_feats=None, + conv_dropout=0, + batch_norm=None, + input_model=None, +): """ unet-style keras model with an overdose of parametrization. @@ -53,13 +56,13 @@ def unet(nb_features, see below for `feat_mult` and `layer_nb_feats` for modifiers to this number input_shape: input layer shape, vector of size ndims + 1 (nb_channels) conv_size: the convolution kernel size - nb_levels: the number of Unet levels (number of downsamples) in the "encoder" + nb_levels: the number of Unet levels (number of downsamples) in the "encoder" (e.g. 4 would give you 4 levels in encoder, 4 in decoder) nb_labels: number of output channels name (default: 'unet'): the name of the network prefix (default: `name` value): prefix to be added to layer names feat_mult (default: 1) multiple for `nb_features` as we go down the encoder levels. - e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the + e.g. feat_mult of 2 and nb_features of 16 would yield 32 features in the second layer, 64 features in the third layer, etc. pool_size (default: 2): max pooling size (integer or list if specifying per dimension) skip_n_concatenations=0: enabled to skip concatenation links between contracting and expanding paths for the n @@ -91,85 +94,97 @@ def unet(nb_features, pool_size = (pool_size,) * ndims # get encoding model - enc_model = conv_enc(nb_features, - input_shape, - nb_levels, - conv_size, - name=model_name, - prefix=prefix, - feat_mult=feat_mult, - pool_size=pool_size, - padding=padding, - dilation_rate_mult=dilation_rate_mult, - activation=activation, - use_residuals=use_residuals, - nb_conv_per_level=nb_conv_per_level, - layer_nb_feats=layer_nb_feats, - conv_dropout=conv_dropout, - batch_norm=batch_norm, - input_model=input_model) + enc_model = conv_enc( + nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + layer_nb_feats=layer_nb_feats, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model, + ) # get decoder # use_skip_connections=True makes it a u-net - lnf = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None - dec_model = conv_dec(nb_features, - [], - nb_levels, - conv_size, - nb_labels, - name=model_name, - prefix=prefix, - feat_mult=feat_mult, - pool_size=pool_size, - use_skip_connections=True, - skip_n_concatenations=skip_n_concatenations, - padding=padding, - dilation_rate_mult=dilation_rate_mult, - activation=activation, - use_residuals=use_residuals, - final_pred_activation='linear' if add_prior_layer else final_pred_activation, - nb_conv_per_level=nb_conv_per_level, - batch_norm=batch_norm, - layer_nb_feats=lnf, - conv_dropout=conv_dropout, - input_model=enc_model) + lnf = ( + layer_nb_feats[(nb_levels * nb_conv_per_level) :] + if layer_nb_feats is not None + else None + ) + dec_model = conv_dec( + nb_features, + [], + nb_levels, + conv_size, + nb_labels, + name=model_name, + prefix=prefix, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=True, + skip_n_concatenations=skip_n_concatenations, + padding=padding, + dilation_rate_mult=dilation_rate_mult, + activation=activation, + use_residuals=use_residuals, + final_pred_activation="linear" if add_prior_layer else final_pred_activation, + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + layer_nb_feats=lnf, + conv_dropout=conv_dropout, + input_model=enc_model, + ) final_model = dec_model if add_prior_layer: - final_model = add_prior(dec_model, - [*input_shape[:-1], nb_labels], - name=model_name + '_prior', - use_logp=use_logp, - final_pred_activation=final_pred_activation) + final_model = add_prior( + dec_model, + [*input_shape[:-1], nb_labels], + name=model_name + "_prior", + use_logp=use_logp, + final_pred_activation=final_pred_activation, + ) return final_model -def ae(nb_features, - input_shape, - nb_levels, - conv_size, - nb_labels, - enc_size, - name='ae', - feat_mult=1, - pool_size=2, - padding='same', - activation='elu', - use_residuals=False, - nb_conv_per_level=1, - batch_norm=None, - enc_batch_norm=None, - ae_type='conv', # 'dense', or 'conv' - enc_lambda_layers=None, - add_prior_layer=False, - use_logp=True, - conv_dropout=0, - include_mu_shift_layer=False, - single_model=False, # whether to return a single model, or a tuple of models that can be stacked. - final_pred_activation='softmax', - do_vae=False, - input_model=None): +def ae( + nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + enc_size, + name="ae", + feat_mult=1, + pool_size=2, + padding="same", + activation="elu", + use_residuals=False, + nb_conv_per_level=1, + batch_norm=None, + enc_batch_norm=None, + ae_type="conv", # 'dense', or 'conv' + enc_lambda_layers=None, + add_prior_layer=False, + use_logp=True, + conv_dropout=0, + include_mu_shift_layer=False, + single_model=False, # whether to return a single model, or a tuple of models that can be stacked. + final_pred_activation="softmax", + do_vae=False, + input_model=None, +): """Convolutional Auto-Encoder. Optionally Variational (if do_vae is set to True).""" # naming @@ -181,20 +196,22 @@ def ae(nb_features, pool_size = (pool_size,) * ndims # get encoding model - enc_model = conv_enc(nb_features, - input_shape, - nb_levels, - conv_size, - name=model_name, - feat_mult=feat_mult, - pool_size=pool_size, - padding=padding, - activation=activation, - use_residuals=use_residuals, - nb_conv_per_level=nb_conv_per_level, - conv_dropout=conv_dropout, - batch_norm=batch_norm, - input_model=input_model) + enc_model = conv_enc( + nb_features, + input_shape, + nb_levels, + conv_size, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + padding=padding, + activation=activation, + use_residuals=use_residuals, + nb_conv_per_level=nb_conv_per_level, + conv_dropout=conv_dropout, + batch_norm=batch_norm, + input_model=input_model, + ) # middle AE structure if single_model: @@ -203,16 +220,18 @@ def ae(nb_features, else: in_input_shape = enc_model.output.shape.as_list()[1:] in_model = None - mid_ae_model = single_ae(enc_size, - in_input_shape, - conv_size=conv_size, - name=model_name, - ae_type=ae_type, - input_model=in_model, - batch_norm=enc_batch_norm, - enc_lambda_layers=enc_lambda_layers, - include_mu_shift_layer=include_mu_shift_layer, - do_vae=do_vae) + mid_ae_model = single_ae( + enc_size, + in_input_shape, + conv_size=conv_size, + name=model_name, + ae_type=ae_type, + input_model=in_model, + batch_norm=enc_batch_norm, + enc_lambda_layers=enc_lambda_layers, + include_mu_shift_layer=include_mu_shift_layer, + do_vae=do_vae, + ) # decoder if single_model: @@ -221,31 +240,35 @@ def ae(nb_features, else: in_input_shape = mid_ae_model.output.shape.as_list()[1:] in_model = None - dec_model = conv_dec(nb_features, - in_input_shape, - nb_levels, - conv_size, - nb_labels, - name=model_name, - feat_mult=feat_mult, - pool_size=pool_size, - use_skip_connections=False, - padding=padding, - activation=activation, - use_residuals=use_residuals, - final_pred_activation='linear', - nb_conv_per_level=nb_conv_per_level, - batch_norm=batch_norm, - conv_dropout=conv_dropout, - input_model=in_model) + dec_model = conv_dec( + nb_features, + in_input_shape, + nb_levels, + conv_size, + nb_labels, + name=model_name, + feat_mult=feat_mult, + pool_size=pool_size, + use_skip_connections=False, + padding=padding, + activation=activation, + use_residuals=use_residuals, + final_pred_activation="linear", + nb_conv_per_level=nb_conv_per_level, + batch_norm=batch_norm, + conv_dropout=conv_dropout, + input_model=in_model, + ) if add_prior_layer: - dec_model = add_prior(dec_model, - [*input_shape[:-1], nb_labels], - name=model_name, - prefix=model_name + '_prior', - use_logp=use_logp, - final_pred_activation=final_pred_activation) + dec_model = add_prior( + dec_model, + [*input_shape[:-1], nb_labels], + name=model_name, + prefix=model_name + "_prior", + use_logp=use_logp, + final_pred_activation=final_pred_activation, + ) if single_model: return dec_model @@ -253,23 +276,25 @@ def ae(nb_features, return dec_model, mid_ae_model, enc_model -def conv_enc(nb_features, - input_shape, - nb_levels, - conv_size, - name=None, - prefix=None, - feat_mult=1, - pool_size=2, - dilation_rate_mult=1, - padding='same', - activation='elu', - layer_nb_feats=None, - use_residuals=False, - nb_conv_per_level=2, - conv_dropout=0, - batch_norm=None, - input_model=None): +def conv_enc( + nb_features, + input_shape, + nb_levels, + conv_size, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + dilation_rate_mult=1, + padding="same", + activation="elu", + layer_nb_feats=None, + use_residuals=False, + nb_conv_per_level=2, + conv_dropout=0, + batch_norm=None, + input_model=None, +): """Fully Convolutional Encoder""" # naming @@ -278,7 +303,7 @@ def conv_enc(nb_features, prefix = model_name # first layer: input - name = '%s_input' % prefix + name = "%s_input" % prefix if input_model is None: input_tensor = KL.Input(shape=input_shape, name=name) last_tensor = input_tensor @@ -294,34 +319,46 @@ def conv_enc(nb_features, pool_size = (pool_size,) * ndims # prepare layers - convL = getattr(KL, 'Conv%dD' % ndims) - conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'} - maxpool = getattr(KL, 'MaxPooling%dD' % ndims) + convL = getattr(KL, "Conv%dD" % ndims) + conv_kwargs = { + "padding": padding, + "activation": activation, + "data_format": "channels_last", + } + maxpool = getattr(KL, "MaxPooling%dD" % ndims) # down arm: # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers lfidx = 0 # level feature index for level in range(nb_levels): lvl_first_tensor = last_tensor - nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int) - conv_kwargs['dilation_rate'] = dilation_rate_mult ** level + nb_lvl_feats = np.round(nb_features * feat_mult**level).astype(int) + conv_kwargs["dilation_rate"] = dilation_rate_mult**level - for conv in range(nb_conv_per_level): # does several conv per level, max pooling applied at the end + for conv in range( + nb_conv_per_level + ): # does several conv per level, max pooling applied at the end if layer_nb_feats is not None: # None or List of all the feature numbers nb_lvl_feats = layer_nb_feats[lfidx] lfidx += 1 - name = '%s_conv_downarm_%d_%d' % (prefix, level, conv) + name = "%s_conv_downarm_%d_%d" % (prefix, level, conv) if conv < (nb_conv_per_level - 1) or (not use_residuals): - last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + last_tensor + ) else: # no activation - last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + last_tensor = convL( + nb_lvl_feats, conv_size, padding=padding, name=name + )(last_tensor) if conv_dropout > 0: # conv dropout along feature space only - name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv) + name = "%s_dropout_downarm_%d_%d" % (prefix, level, conv) noise_shape = [None, *[1] * ndims, nb_lvl_feats] - last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + last_tensor = KL.Dropout( + conv_dropout, noise_shape=noise_shape, name=name + )(last_tensor) if use_residuals: convarm_layer = last_tensor @@ -332,55 +369,63 @@ def conv_enc(nb_features, nb_feats_out = convarm_layer.get_shape()[-1] add_layer = lvl_first_tensor if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): - name = '%s_expand_down_merge_%d' % (prefix, level) - last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor) + name = "%s_expand_down_merge_%d" % (prefix, level) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + lvl_first_tensor + ) add_layer = last_tensor if conv_dropout > 0: noise_shape = [None, *[1] * ndims, nb_lvl_feats] - convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + convarm_layer = KL.Dropout(conv_dropout, noise_shape=noise_shape)( + last_tensor + ) - name = '%s_res_down_merge_%d' % (prefix, level) + name = "%s_res_down_merge_%d" % (prefix, level) last_tensor = KL.add([add_layer, convarm_layer], name=name) - name = '%s_res_down_merge_act_%d' % (prefix, level) + name = "%s_res_down_merge_act_%d" % (prefix, level) last_tensor = KL.Activation(activation, name=name)(last_tensor) if batch_norm is not None: - name = '%s_bn_down_%d' % (prefix, level) + name = "%s_bn_down_%d" % (prefix, level) last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # max pool if we're not at the last level if level < (nb_levels - 1): - name = '%s_maxpool_%d' % (prefix, level) - last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor) + name = "%s_maxpool_%d" % (prefix, level) + last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)( + last_tensor + ) # create the model and return model = Model(inputs=input_tensor, outputs=[last_tensor], name=model_name) return model -def conv_dec(nb_features, - input_shape, - nb_levels, - conv_size, - nb_labels, - name=None, - prefix=None, - feat_mult=1, - pool_size=2, - use_skip_connections=False, - skip_n_concatenations=0, - padding='same', - dilation_rate_mult=1, - activation='elu', - use_residuals=False, - final_pred_activation='softmax', - nb_conv_per_level=2, - layer_nb_feats=None, - batch_norm=None, - conv_dropout=0, - input_model=None): +def conv_dec( + nb_features, + input_shape, + nb_levels, + conv_size, + nb_labels, + name=None, + prefix=None, + feat_mult=1, + pool_size=2, + use_skip_connections=False, + skip_n_concatenations=0, + padding="same", + dilation_rate_mult=1, + activation="elu", + use_residuals=False, + final_pred_activation="softmax", + nb_conv_per_level=2, + layer_nb_feats=None, + batch_norm=None, + conv_dropout=0, + input_model=None, +): """Fully Convolutional Decoder""" # naming @@ -390,10 +435,12 @@ def conv_dec(nb_features, # if using skip connections, make sure need to use them. if use_skip_connections: - assert input_model is not None, "is using skip connections, tensors dictionary is required" + assert ( + input_model is not None + ), "is using skip connections, tensors dictionary is required" # first layer: input - input_name = '%s_input' % prefix + input_name = "%s_input" % prefix if input_model is None: input_tensor = KL.Input(shape=input_shape, name=input_name) last_tensor = input_tensor @@ -409,29 +456,37 @@ def conv_dec(nb_features, pool_size = (pool_size,) * ndims # prepare layers - convL = getattr(KL, 'Conv%dD' % ndims) - conv_kwargs = {'padding': padding, 'activation': activation} - upsample = getattr(KL, 'UpSampling%dD' % ndims) + convL = getattr(KL, "Conv%dD" % ndims) + conv_kwargs = {"padding": padding, "activation": activation} + upsample = getattr(KL, "UpSampling%dD" % ndims) # up arm: # nb_levels - 1 layers of Deconvolution3D # (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu lfidx = 0 for level in range(nb_levels - 1): - nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int) - conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level) + nb_lvl_feats = np.round( + nb_features * feat_mult ** (nb_levels - 2 - level) + ).astype(int) + conv_kwargs["dilation_rate"] = dilation_rate_mult ** (nb_levels - 2 - level) # upsample matching the max pooling layers size - name = '%s_up_%d' % (prefix, nb_levels + level) + name = "%s_up_%d" % (prefix, nb_levels + level) last_tensor = upsample(size=pool_size, name=name)(last_tensor) up_tensor = last_tensor # merge layers combining previous layer if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)): - conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1) + conv_name = "%s_conv_downarm_%d_%d" % ( + prefix, + nb_levels - 2 - level, + nb_conv_per_level - 1, + ) cat_tensor = input_model.get_layer(conv_name).output - name = '%s_merge_%d' % (prefix, nb_levels + level) - last_tensor = KL.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name) + name = "%s_merge_%d" % (prefix, nb_levels + level) + last_tensor = KL.concatenate( + [cat_tensor, last_tensor], axis=ndims + 1, name=name + ) # convolution layers for conv in range(nb_conv_per_level): @@ -439,16 +494,22 @@ def conv_dec(nb_features, nb_lvl_feats = layer_nb_feats[lfidx] lfidx += 1 - name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv) + name = "%s_conv_uparm_%d_%d" % (prefix, nb_levels + level, conv) if conv < (nb_conv_per_level - 1) or (not use_residuals): - last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor) + last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + last_tensor + ) else: - last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor) + last_tensor = convL( + nb_lvl_feats, conv_size, padding=padding, name=name + )(last_tensor) if conv_dropout > 0: - name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv) + name = "%s_dropout_uparm_%d_%d" % (prefix, level, conv) noise_shape = [None, *[1] * ndims, nb_lvl_feats] - last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor) + last_tensor = KL.Dropout( + conv_dropout, noise_shape=noise_shape, name=name + )(last_tensor) # residual block if use_residuals: @@ -459,51 +520,57 @@ def conv_dec(nb_features, nb_feats_in = add_layer.get_shape()[-1] nb_feats_out = last_tensor.get_shape()[-1] if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out): - name = '%s_expand_up_merge_%d' % (prefix, level) - add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer) + name = "%s_expand_up_merge_%d" % (prefix, level) + add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)( + add_layer + ) if conv_dropout > 0: noise_shape = [None, *[1] * ndims, nb_lvl_feats] - last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor) + last_tensor = KL.Dropout(conv_dropout, noise_shape=noise_shape)( + last_tensor + ) - name = '%s_res_up_merge_%d' % (prefix, level) + name = "%s_res_up_merge_%d" % (prefix, level) last_tensor = KL.add([last_tensor, add_layer], name=name) - name = '%s_res_up_merge_act_%d' % (prefix, level) + name = "%s_res_up_merge_act_%d" % (prefix, level) last_tensor = KL.Activation(activation, name=name)(last_tensor) if batch_norm is not None: - name = '%s_bn_up_%d' % (prefix, level) + name = "%s_bn_up_%d" % (prefix, level) last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # Compute likelihood prediction (no activation yet) - name = '%s_likelihood' % prefix + name = "%s_likelihood" % prefix last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor) like_tensor = last_tensor # output prediction layer # we use a softmax to compute P(L_x|I) where x is each location - if final_pred_activation == 'softmax': - name = '%s_prediction' % prefix + if final_pred_activation == "softmax": + name = "%s_prediction" % prefix softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1) pred_tensor = KL.Lambda(softmax_lambda_fcn, name=name)(last_tensor) # otherwise create a layer that does nothing. else: - name = '%s_prediction' % prefix - pred_tensor = KL.Activation('linear', name=name)(like_tensor) + name = "%s_prediction" % prefix + pred_tensor = KL.Activation("linear", name=name)(like_tensor) # create the model and return model = Model(inputs=input_tensor, outputs=pred_tensor, name=model_name) return model -def add_prior(input_model, - prior_shape, - name='prior_model', - prefix=None, - use_logp=True, - final_pred_activation='softmax'): +def add_prior( + input_model, + prior_shape, + name="prior_model", + prefix=None, + use_logp=True, + final_pred_activation="softmax", +): """ Append post-prior layer to a given model """ @@ -514,38 +581,43 @@ def add_prior(input_model, prefix = model_name # prior input layer - prior_input_name = '%s-input' % prefix + prior_input_name = "%s-input" % prefix prior_tensor = KL.Input(shape=prior_shape, name=prior_input_name) prior_tensor_input = prior_tensor like_tensor = input_model.output # operation varies depending on whether we log() prior or not. if use_logp: - print("Breaking change: use_logp option now requires log input!", file=sys.stderr) + print( + "Breaking change: use_logp option now requires log input!", file=sys.stderr + ) merge_op = KL.add else: # using sigmoid to get the likelihood values between 0 and 1 # note: they won't add up to 1. - name = '%s_likelihood_sigmoid' % prefix - like_tensor = KL.Activation('sigmoid', name=name)(like_tensor) + name = "%s_likelihood_sigmoid" % prefix + like_tensor = KL.Activation("sigmoid", name=name)(like_tensor) merge_op = KL.multiply # merge the likelihood and prior layers into posterior layer - name = '%s_posterior' % prefix + name = "%s_posterior" % prefix post_tensor = merge_op([prior_tensor, like_tensor], name=name) # output prediction layer # we use a softmax to compute P(L_x|I) where x is each location - pred_name = '%s_prediction' % prefix - if final_pred_activation == 'softmax': - assert use_logp, 'cannot do softmax when adding prior via P()' - print("using final_pred_activation %s for %s" % (final_pred_activation, model_name)) + pred_name = "%s_prediction" % prefix + if final_pred_activation == "softmax": + assert use_logp, "cannot do softmax when adding prior via P()" + print( + "using final_pred_activation %s for %s" + % (final_pred_activation, model_name) + ) softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=-1) pred_tensor = KL.Lambda(softmax_lambda_fcn, name=pred_name)(post_tensor) else: - pred_tensor = KL.Activation('linear', name=pred_name)(post_tensor) + pred_tensor = KL.Activation("linear", name=pred_name)(post_tensor) # create the model model_inputs = [*input_model.inputs, prior_tensor_input] @@ -555,19 +627,21 @@ def add_prior(input_model, return model -def single_ae(enc_size, - input_shape, - name='single_ae', - prefix=None, - ae_type='dense', # 'dense', or 'conv' - conv_size=None, - input_model=None, - enc_lambda_layers=None, - batch_norm=True, - padding='same', - activation=None, - include_mu_shift_layer=False, - do_vae=False): +def single_ae( + enc_size, + input_shape, + name="single_ae", + prefix=None, + ae_type="dense", # 'dense', or 'conv' + conv_size=None, + input_model=None, + enc_lambda_layers=None, + batch_norm=True, + padding="same", + activation=None, + include_mu_shift_layer=False, + do_vae=False, +): """single-layer Autoencoder (i.e. input - encoding - output""" # naming @@ -579,9 +653,9 @@ def single_ae(enc_size, enc_lambda_layers = [] # prepare input - input_name = '%s_input' % prefix + input_name = "%s_input" % prefix if input_model is None: - assert input_shape is not None, 'input_shape of input_model is necessary' + assert input_shape is not None, "input_shape of input_model is necessary" input_tensor = KL.Input(shape=input_shape, name=input_name) last_tensor = input_tensor else: @@ -592,10 +666,10 @@ def single_ae(enc_size, # prepare conv type based on input ndims = len(input_shape) - 1 - if ae_type == 'conv': - convL = getattr(KL, 'Conv%dD' % ndims) - assert conv_size is not None, 'with conv ae, need conv_size' - conv_kwargs = {'padding': padding, 'activation': activation} + if ae_type == "conv": + convL = getattr(KL, "Conv%dD" % ndims) + assert conv_size is not None, "with conv ae, need conv_size" + conv_kwargs = {"padding": padding, "activation": activation} enc_size_str = None # if want to go through a dense layer in the middle of the U, need to: @@ -604,62 +678,75 @@ def single_ae(enc_size, # - unflatten (reshape spatially) at end else: # ae_type == 'dense' if len(input_shape) > 1: - name = '%s_ae_%s_down_flat' % (prefix, ae_type) + name = "%s_ae_%s_down_flat" % (prefix, ae_type) last_tensor = KL.Flatten(name=name)(last_tensor) convL = conv_kwargs = None assert len(enc_size) == 1, "enc_size should be of length 1 for dense layer" - enc_size_str = ''.join(['%d_' % d for d in enc_size])[:-1] + enc_size_str = "".join(["%d_" % d for d in enc_size])[:-1] # recall this layer pre_enc_layer = last_tensor # encoding layer - if ae_type == 'dense': - name = '%s_ae_mu_enc_dense_%s' % (prefix, enc_size_str) + if ae_type == "dense": + name = "%s_ae_mu_enc_dense_%s" % (prefix, enc_size_str) last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) else: # convolution # convolve then resize. enc_size should be [nb_dim1, nb_dim2, ..., nb_feats] - assert len(enc_size) == len(input_shape), \ - "encoding size does not match input shape %d %d" % (len(enc_size), len(input_shape)) - - if list(enc_size)[:-1] != list(input_shape)[:-1] and \ - all([f is not None for f in input_shape[:-1]]) and \ - all([f is not None for f in enc_size[:-1]]): - - name = '%s_ae_mu_enc_conv' % prefix - last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) - - name = '%s_ae_mu_enc' % prefix - zf = [enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] for f in range(len(enc_size) - 1)] + assert len(enc_size) == len( + input_shape + ), "encoding size does not match input shape %d %d" % ( + len(enc_size), + len(input_shape), + ) + + if ( + list(enc_size)[:-1] != list(input_shape)[:-1] + and all([f is not None for f in input_shape[:-1]]) + and all([f is not None for f in enc_size[:-1]]) + ): + + name = "%s_ae_mu_enc_conv" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) + + name = "%s_ae_mu_enc" % prefix + zf = [ + enc_size[:-1][f] / last_tensor.shape.as_list()[1:-1][f] + for f in range(len(enc_size) - 1) + ] last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck - name = '%s_ae_mu_enc' % prefix + name = "%s_ae_mu_enc" % prefix last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) else: - name = '%s_ae_mu_enc' % prefix - last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + name = "%s_ae_mu_enc" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) if include_mu_shift_layer: # shift - name = '%s_ae_mu_shift' % prefix + name = "%s_ae_mu_shift" % prefix last_tensor = layers.LocalBias(name=name)(last_tensor) # encoding clean-up layers for layer_fcn in enc_lambda_layers: lambda_name = layer_fcn.__name__ - name = '%s_ae_mu_%s' % (prefix, lambda_name) + name = "%s_ae_mu_%s" % (prefix, lambda_name) last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) if batch_norm is not None: - name = '%s_ae_mu_bn' % prefix + name = "%s_ae_mu_bn" % prefix last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # have a simple layer that does nothing to have a clear name before sampling - name = '%s_ae_mu' % prefix + name = "%s_ae_mu" % prefix last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) # if doing variational AE, will need the sigma layer as well. @@ -667,46 +754,58 @@ def single_ae(enc_size, mu_tensor = last_tensor # encoding layer - if ae_type == 'dense': - name = '%s_ae_sigma_enc_dense_%s' % (prefix, enc_size_str) + if ae_type == "dense": + name = "%s_ae_sigma_enc_dense_%s" % (prefix, enc_size_str) last_tensor = KL.Dense(enc_size[0], name=name)(pre_enc_layer) else: - if list(enc_size)[:-1] != list(input_shape)[:-1] and \ - all([f is not None for f in input_shape[:-1]]) and \ - all([f is not None for f in enc_size[:-1]]): - - assert len(enc_size) - 1 == 2, "Sorry, I have not yet implemented non-2D resizing..." - name = '%s_ae_sigma_enc_conv' % prefix - last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) - - name = '%s_ae_sigma_enc' % prefix + if ( + list(enc_size)[:-1] != list(input_shape)[:-1] + and all([f is not None for f in input_shape[:-1]]) + and all([f is not None for f in enc_size[:-1]]) + ): + + assert ( + len(enc_size) - 1 == 2 + ), "Sorry, I have not yet implemented non-2D resizing..." + name = "%s_ae_sigma_enc_conv" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) + + name = "%s_ae_sigma_enc" % prefix resize_fn = lambda x: tf.image.resize_bilinear(x, enc_size[:-1]) last_tensor = KL.Lambda(resize_fn, name=name)(last_tensor) elif enc_size[-1] is None: # convolutional, but won't tell us bottleneck - name = '%s_ae_sigma_enc' % prefix - last_tensor = convL(pre_enc_layer.shape.as_list()[-1], conv_size, name=name, **conv_kwargs)( - pre_enc_layer) + name = "%s_ae_sigma_enc" % prefix + last_tensor = convL( + pre_enc_layer.shape.as_list()[-1], + conv_size, + name=name, + **conv_kwargs + )(pre_enc_layer) # cannot use lambda, then mu and sigma will be same layer. # last_tensor = KL.Lambda(lambda x: x, name=name)(pre_enc_layer) else: - name = '%s_ae_sigma_enc' % prefix - last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)(pre_enc_layer) + name = "%s_ae_sigma_enc" % prefix + last_tensor = convL(enc_size[-1], conv_size, name=name, **conv_kwargs)( + pre_enc_layer + ) # encoding clean-up layers for layer_fcn in enc_lambda_layers: lambda_name = layer_fcn.__name__ - name = '%s_ae_sigma_%s' % (prefix, lambda_name) + name = "%s_ae_sigma_%s" % (prefix, lambda_name) last_tensor = KL.Lambda(layer_fcn, name=name)(last_tensor) if batch_norm is not None: - name = '%s_ae_sigma_bn' % prefix + name = "%s_ae_sigma_bn" % prefix last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # have a simple layer that does nothing to have a clear name before sampling - name = '%s_ae_sigma' % prefix + name = "%s_ae_sigma" % prefix last_tensor = KL.Lambda(lambda x: x, name=name)(last_tensor) logvar_tensor = last_tensor @@ -714,38 +813,45 @@ def single_ae(enc_size, # VAE sampling sampler = _VAESample().sample_z - name = '%s_ae_sample' % prefix + name = "%s_ae_sample" % prefix last_tensor = KL.Lambda(sampler, name=name)([mu_tensor, logvar_tensor]) if include_mu_shift_layer: # shift - name = '%s_ae_sample_shift' % prefix + name = "%s_ae_sample_shift" % prefix last_tensor = layers.LocalBias(name=name)(last_tensor) # decoding layer - if ae_type == 'dense': - name = '%s_ae_%s_dec_flat_%s' % (prefix, ae_type, enc_size_str) + if ae_type == "dense": + name = "%s_ae_%s_dec_flat_%s" % (prefix, ae_type, enc_size_str) last_tensor = KL.Dense(np.prod(input_shape), name=name)(last_tensor) # unflatten if dense method if len(input_shape) > 1: - name = '%s_ae_%s_dec' % (prefix, ae_type) + name = "%s_ae_%s_dec" % (prefix, ae_type) last_tensor = KL.Reshape(input_shape, name=name)(last_tensor) else: - if list(enc_size)[:-1] != list(input_shape)[:-1] and \ - all([f is not None for f in input_shape[:-1]]) and \ - all([f is not None for f in enc_size[:-1]]): - name = '%s_ae_mu_dec' % prefix - zf = [last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] for f in range(len(enc_size) - 1)] + if ( + list(enc_size)[:-1] != list(input_shape)[:-1] + and all([f is not None for f in input_shape[:-1]]) + and all([f is not None for f in enc_size[:-1]]) + ): + name = "%s_ae_mu_dec" % prefix + zf = [ + last_tensor.shape.as_list()[1:-1][f] / enc_size[:-1][f] + for f in range(len(enc_size) - 1) + ] last_tensor = layers.Resize(zoom_factor=zf, name=name)(last_tensor) - name = '%s_ae_%s_dec' % (prefix, ae_type) - last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)(last_tensor) + name = "%s_ae_%s_dec" % (prefix, ae_type) + last_tensor = convL(input_nb_feats, conv_size, name=name, **conv_kwargs)( + last_tensor + ) if batch_norm is not None: - name = '%s_bn_ae_%s_dec' % (prefix, ae_type) + name = "%s_bn_ae_%s_dec" % (prefix, ae_type) last_tensor = KL.BatchNormalization(axis=batch_norm, name=name)(last_tensor) # create the model and return @@ -757,6 +863,7 @@ def single_ae(enc_size, # Helper function ############################################################################### + class _VAESample: def __init__(self): pass @@ -764,5 +871,5 @@ def __init__(self): def sample_z(self, args): mu, log_var = args shape = K.shape(mu) - eps = K.random_normal(shape=shape, mean=0., stddev=1.) + eps = K.random_normal(shape=shape, mean=0.0, stddev=1.0) return mu + K.exp(log_var / 2) * eps diff --git a/nobrainer/ext/neuron/utils.py b/nobrainer/ext/neuron/utils.py index 1162b51c..c6b94028 100644 --- a/nobrainer/ext/neuron/utils.py +++ b/nobrainer/ext/neuron/utils.py @@ -1,9 +1,9 @@ """ tensorflow/keras utilities for the neuron project -If you use this code, please cite +If you use this code, please cite Dalca AV, Guttag J, Sabuncu MR -Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, +Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, CVPR 2018 or for the transformation/interpolation related functions: @@ -17,16 +17,17 @@ """ import itertools + +import keras.backend as K import numpy as np import tensorflow as tf -import keras.backend as K -def interpn(vol, loc, interp_method='linear'): +def interpn(vol, loc, interp_method="linear"): """ N-D gridded interpolation in tensorflow - vol can have more dimensions than loc[i], in which case loc[i] acts as a slice + vol can have more dimensions than loc[i], in which case loc[i] acts as a slice for the first dimensions Parameters: @@ -45,18 +46,22 @@ def interpn(vol, loc, interp_method='linear'): nb_dims = loc.shape[-1] if len(vol.shape) not in [nb_dims, nb_dims + 1]: - raise Exception("Number of loc Tensors %d does not match volume dimension %d" - % (nb_dims, len(vol.shape[:-1]))) + raise Exception( + "Number of loc Tensors %d does not match volume dimension %d" + % (nb_dims, len(vol.shape[:-1])) + ) if nb_dims > len(vol.shape): - raise Exception("Loc dimension %d does not match volume dimension %d" - % (nb_dims, len(vol.shape))) + raise Exception( + "Loc dimension %d does not match volume dimension %d" + % (nb_dims, len(vol.shape)) + ) if len(vol.shape) == nb_dims: vol = K.expand_dims(vol, -1) # flatten and float location Tensors - loc = tf.cast(loc, 'float32') + loc = tf.cast(loc, "float32") if isinstance(vol.shape, tf.TensorShape): volshape = vol.shape.as_list() @@ -64,26 +69,36 @@ def interpn(vol, loc, interp_method='linear'): volshape = vol.shape # interpolate - if interp_method == 'linear': + if interp_method == "linear": loc0 = tf.floor(loc) # clip values max_loc = [d - 1 for d in vol.get_shape().as_list()] - clipped_loc = [tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims)] - loc0lst = [tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims)] + clipped_loc = [ + tf.clip_by_value(loc[..., d], 0, max_loc[d]) for d in range(nb_dims) + ] + loc0lst = [ + tf.clip_by_value(loc0[..., d], 0, max_loc[d]) for d in range(nb_dims) + ] # get other end of point cube loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)] - locs = [[tf.cast(f, 'int32') for f in loc0lst], [tf.cast(f, 'int32') for f in loc1]] + locs = [ + [tf.cast(f, "int32") for f in loc0lst], + [tf.cast(f, "int32") for f in loc1], + ] # compute the difference between the upper value and the original value # differences are basically 1 - (pt - floor(pt)) # because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt)) diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)] diff_loc0 = [1 - d for d in diff_loc1] - weights_loc = [diff_loc1, diff_loc0] # note reverse ordering since weights are inverse of diff. + weights_loc = [ + diff_loc1, + diff_loc0, + ] # note reverse ordering since weights are inverse of diff. - # go through all the cube corners, indexed by a ND binary vector + # go through all the cube corners, indexed by a ND binary vector # e.g. [0, 0] means this "first" corner in a 2-D "cube" cube_pts = list(itertools.product([0, 1], repeat=nb_dims)) interp_vol = 0 @@ -110,12 +125,14 @@ def interpn(vol, loc, interp_method='linear'): interp_vol += wt * vol_val else: - assert interp_method == 'nearest' - roundloc = tf.cast(tf.round(loc), 'int32') + assert interp_method == "nearest" + roundloc = tf.cast(tf.round(loc), "int32") # clip values - max_loc = [tf.cast(d - 1, 'int32') for d in vol.shape] - roundloc = [tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims)] + max_loc = [tf.cast(d - 1, "int32") for d in vol.shape] + roundloc = [ + tf.clip_by_value(roundloc[..., d], 0, max_loc[d]) for d in range(nb_dims) + ] # get values idx = sub2ind(vol.shape[:-1], roundloc) @@ -124,7 +141,7 @@ def interpn(vol, loc, interp_method='linear'): return interp_vol -def resize(vol, zoom_factor, new_shape, interp_method='linear'): +def resize(vol, zoom_factor, new_shape, interp_method="linear"): """ if zoom_factor is a list, it will determine the ndims, in which case vol has to be of length ndims or ndims + 1 @@ -137,8 +154,10 @@ def resize(vol, zoom_factor, new_shape, interp_method='linear'): if isinstance(zoom_factor, (list, tuple)): ndims = len(zoom_factor) vol_shape = vol.shape[:ndims] - assert len(vol_shape) in (ndims, ndims + 1), \ - "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims) + assert len(vol_shape) in ( + ndims, + ndims + 1, + ), "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims) else: vol_shape = vol.shape[:-1] ndims = len(vol_shape) @@ -146,7 +165,7 @@ def resize(vol, zoom_factor, new_shape, interp_method='linear'): # get grid for new shape grid = volshape_to_ndgrid(new_shape) - grid = [tf.cast(f, 'float32') for f in grid] + grid = [tf.cast(f, "float32") for f in grid] offset = [grid[f] / zoom_factor[f] - grid[f] for f in range(ndims)] offset = tf.stack(offset, ndims) @@ -157,7 +176,7 @@ def resize(vol, zoom_factor, new_shape, interp_method='linear'): zoom = resize -def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'): +def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing="ij"): """ transform an affine matrix to a dense location shift tensor in tensorflow @@ -179,35 +198,43 @@ def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'): if isinstance(volshape, tf.TensorShape): volshape = volshape.as_list() - if affine_matrix.dtype != 'float32': - affine_matrix = tf.cast(affine_matrix, 'float32') + if affine_matrix.dtype != "float32": + affine_matrix = tf.cast(affine_matrix, "float32") nb_dims = len(volshape) if len(affine_matrix.shape) == 1: if len(affine_matrix) != (nb_dims * (nb_dims + 1)): - raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).' - 'Got len %d' % len(affine_matrix)) + raise ValueError( + "transform is supposed a vector of len ndims * (ndims + 1)." + "Got len %d" % len(affine_matrix) + ) affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1]) - if not (affine_matrix.shape[0] in [nb_dims, nb_dims + 1] and affine_matrix.shape[1] == (nb_dims + 1)): - raise Exception('Affine matrix shape should match' - '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + - '%d x %d+1.' % (nb_dims, nb_dims) + - 'Got: ' + str(volshape)) + if not ( + affine_matrix.shape[0] in [nb_dims, nb_dims + 1] + and affine_matrix.shape[1] == (nb_dims + 1) + ): + raise Exception( + "Affine matrix shape should match" + "%d+1 x %d+1 or " % (nb_dims, nb_dims) + + "%d x %d+1." % (nb_dims, nb_dims) + + "Got: " + + str(volshape) + ) # list of volume ndgrid # N-long list, each entry of shape volshape mesh = volshape_to_meshgrid(volshape, indexing=indexing) - mesh = [tf.cast(f, 'float32') for f in mesh] + mesh = [tf.cast(f, "float32") for f in mesh] if shift_center: mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] # add an all-ones entry and transform into a large matrix flat_mesh = [flatten(f) for f in mesh] - flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32')) + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype="float32")) mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # 4 x nb_voxels # compute locations @@ -219,7 +246,9 @@ def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'): return loc - tf.stack(mesh, axis=nb_dims) -def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=True, indexing='ij'): +def combine_non_linear_and_aff_to_shift( + transform_list, volshape, shift_center=True, indexing="ij" +): """ transform an affine matrix to a dense location shift tensor in tensorflow @@ -243,29 +272,37 @@ def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=T # convert transforms to floats for i in range(len(transform_list)): - if transform_list[i].dtype != 'float32': - transform_list[i] = tf.cast(transform_list[i], 'float32') + if transform_list[i].dtype != "float32": + transform_list[i] = tf.cast(transform_list[i], "float32") nb_dims = len(volshape) # transform affine to matrix if given as vector if len(transform_list[1].shape) == 1: if len(transform_list[1]) != (nb_dims * (nb_dims + 1)): - raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).' - 'Got len %d' % len(transform_list[1])) + raise ValueError( + "transform is supposed a vector of len ndims * (ndims + 1)." + "Got len %d" % len(transform_list[1]) + ) transform_list[1] = tf.reshape(transform_list[1], [nb_dims, nb_dims + 1]) - if not (transform_list[1].shape[0] in [nb_dims, nb_dims + 1] and transform_list[1].shape[1] == (nb_dims + 1)): - raise Exception('Affine matrix shape should match' - '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + - '%d x %d+1.' % (nb_dims, nb_dims) + - 'Got: ' + str(volshape)) + if not ( + transform_list[1].shape[0] in [nb_dims, nb_dims + 1] + and transform_list[1].shape[1] == (nb_dims + 1) + ): + raise Exception( + "Affine matrix shape should match" + "%d+1 x %d+1 or " % (nb_dims, nb_dims) + + "%d x %d+1." % (nb_dims, nb_dims) + + "Got: " + + str(volshape) + ) # list of volume ndgrid # N-long list, each entry of shape volshape mesh = volshape_to_meshgrid(volshape, indexing=indexing) - mesh = [tf.cast(f, 'float32') for f in mesh] + mesh = [tf.cast(f, "float32") for f in mesh] if shift_center: mesh = [mesh[f] - (volshape[f] - 1) / 2 for f in range(len(volshape))] @@ -273,8 +310,8 @@ def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=T # add an all-ones entry and transform into a large matrix # non_linear_mesh = tf.unstack(transform_list[0], axis=3) non_linear_mesh = tf.unstack(transform_list[0], axis=-1) - flat_mesh = [flatten(mesh[i]+non_linear_mesh[i]) for i in range(len(mesh))] - flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32')) + flat_mesh = [flatten(mesh[i] + non_linear_mesh[i]) for i in range(len(mesh))] + flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype="float32")) mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1)) # N+1 x nb_voxels # compute locations @@ -286,12 +323,12 @@ def combine_non_linear_and_aff_to_shift(transform_list, volshape, shift_center=T return loc - tf.stack(mesh, axis=nb_dims) -def transform(vol, loc_shift, interp_method='linear', indexing='ij'): +def transform(vol, loc_shift, interp_method="linear", indexing="ij"): """ transform interpolation N-D volumes (features) given shifts at each location in tensorflow - Essentially interpolates volume vol at locations determined by loc_shift. - This is a spatial transform in the sense that at location [x] we now have the data from, + Essentially interpolates volume vol at locations determined by loc_shift. + This is a spatial transform in the sense that at location [x] we now have the data from, [x + shift] so we've moved data. Parameters: @@ -300,7 +337,7 @@ def transform(vol, loc_shift, interp_method='linear', indexing='ij'): interp_method (default:'linear'): 'linear', 'nearest' indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian). In general, prefer to leave this 'ij' - + Return: new interpolated volumes in the same size as loc_shift[0] """ @@ -314,29 +351,29 @@ def transform(vol, loc_shift, interp_method='linear', indexing='ij'): # location should be meshed and delta mesh = volshape_to_meshgrid(volshape, indexing=indexing) # volume mesh - loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)] + loc = [tf.cast(mesh[d], "float32") + loc_shift[..., d] for d in range(nb_dims)] # test single return interpn(vol, loc, interp_method=interp_method) -def integrate_vec(vec, time_dep=False, method='ss', **kwargs): +def integrate_vec(vec, time_dep=False, method="ss", **kwargs): """ Integrate (stationary of time-dependent) vector field (N-D Tensor) in tensorflow - - Aside from directly using tensorflow's numerical integration odeint(), also implements + + Aside from directly using tensorflow's numerical integration odeint(), also implements "scaling and squaring", and quadrature. Note that the diff. equation given to odeint - is the one used in quadrature. + is the one used in quadrature. Parameters: - vec: the Tensor field to integrate. + vec: the Tensor field to integrate. If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size), - then vector shape (vec_shape) should be + then vector shape (vec_shape) should be [vol_size, vol_ndim] (if stationary) [vol_size, vol_ndim, nb_time_steps] (if time dependent) time_dep: bool whether vector is time dependent method: 'scaling_and_squaring' or 'ss' or 'quadrature' - + if using 'scaling_and_squaring': currently only supports integrating to time point 1. nb_steps int number of steps. Note that this means the vec field gets broken own to 2**nb_steps. so nb_steps of 0 means integral = vec. @@ -345,32 +382,36 @@ def integrate_vec(vec, time_dep=False, method='ss', **kwargs): int_vec: integral of vector field with same shape as the input """ - if method not in ['ss', 'scaling_and_squaring', 'ode', 'quadrature']: - raise ValueError("method has to be 'scaling_and_squaring' or 'ode'. found: %s" % method) + if method not in ["ss", "scaling_and_squaring", "ode", "quadrature"]: + raise ValueError( + "method has to be 'scaling_and_squaring' or 'ode'. found: %s" % method + ) - if method in ['ss', 'scaling_and_squaring']: - nb_steps = kwargs['nb_steps'] - assert nb_steps >= 0, 'nb_steps should be >= 0, found: %d' % nb_steps + if method in ["ss", "scaling_and_squaring"]: + nb_steps = kwargs["nb_steps"] + assert nb_steps >= 0, "nb_steps should be >= 0, found: %d" % nb_steps if time_dep: svec = K.permute_dimensions(vec, [-1, *range(0, vec.shape[-1] - 1)]) - assert 2 ** nb_steps == svec.shape[0], "2**nb_steps and vector shape don't match" + assert ( + 2**nb_steps == svec.shape[0] + ), "2**nb_steps and vector shape don't match" - svec = svec / (2 ** nb_steps) + svec = svec / (2**nb_steps) for _ in range(nb_steps): svec = svec[0::2] + tf.map_fn(transform, svec[1::2, :], svec[0::2, :]) disp = svec[0, :] else: - vec = vec / (2 ** nb_steps) + vec = vec / (2**nb_steps) for _ in range(nb_steps): vec += transform(vec, vec) disp = vec else: # method == 'quadrature': - nb_steps = kwargs['nb_steps'] - assert nb_steps >= 1, 'nb_steps should be >= 1, found: %d' % nb_steps + nb_steps = kwargs["nb_steps"] + assert nb_steps >= 1, "nb_steps should be >= 1, found: %d" % nb_steps vec = vec / nb_steps @@ -441,18 +482,18 @@ def ndgrid(*args, **kwargs): Returns: A list of Tensors - + """ - return meshgrid(*args, indexing='ij', **kwargs) + return meshgrid(*args, indexing="ij", **kwargs) def meshgrid(*args, **kwargs): """ - + meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically improves runtime by changing the last step to tiling instead of multiplication. https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921 - + Broadcasts parameters for evaluation on an N-D grid. Given N one-dimensional coordinate arrays `*args`, returns a list `outputs` of N-D coordinate arrays for evaluating expressions on an N-D grid. @@ -488,8 +529,9 @@ def meshgrid(*args, **kwargs): indexing = kwargs.pop("indexing", "xy") if kwargs: key = list(kwargs.keys())[0] - raise TypeError("'{}' is an invalid keyword argument " - "for this function".format(key)) + raise TypeError( + "'{}' is an invalid keyword argument " "for this function".format(key) + ) if indexing not in ("xy", "ij"): raise ValueError("indexing parameter must be either 'xy' or 'ij'") @@ -501,7 +543,7 @@ def meshgrid(*args, **kwargs): # Prepare reshape by inserting dimensions with size 1 where needed output = [] for i, x in enumerate(args): - output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::]))) + output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1 : :]))) # Create parameters for broadcasting each tensor to the full size shapes = [tf.size(x) for x in args] sz = [x.get_shape().as_list()[0] for x in args] @@ -514,8 +556,8 @@ def meshgrid(*args, **kwargs): sz[0], sz[1] = sz[1], sz[0] for i in range(len(output)): - stack_sz = [*sz[:i], 1, *sz[(i + 1):]] - if indexing == 'xy' and ndim > 1 and i < 2: + stack_sz = [*sz[:i], 1, *sz[(i + 1) :]] + if indexing == "xy" and ndim > 1 and i < 2: stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0] output[i] = tf.tile(output[i], tf.stack(stack_sz)) return output @@ -537,7 +579,10 @@ def prod_n(lst): def sub2ind(siz, subs): """assumes column-order major""" # subs is a list - assert len(siz) == len(subs), 'found inconsistent siz and subs: %d %d' % (len(siz), len(subs)) + assert len(siz) == len(subs), "found inconsistent siz and subs: %d %d" % ( + len(siz), + len(subs), + ) k = np.cumprod(siz[::-1]) diff --git a/nobrainer/models/__init__.py b/nobrainer/models/__init__.py index 27d83502..23420f21 100644 --- a/nobrainer/models/__init__.py +++ b/nobrainer/models/__init__.py @@ -7,12 +7,12 @@ from .bayesian_vnet import bayesian_vnet from .dcgan import dcgan from .highresnet import highresnet +from .labels_to_image_model import labels_to_image_model from .meshnet import meshnet from .progressiveae import progressiveae from .progressivegan import progressivegan from .unet import unet from .unetr import unetr -from .labels_to_image_model import labels_to_image_model __all__ = ["get", "list_available_models"] @@ -29,7 +29,7 @@ "unetr": unetr, "variational_meshnet": variational_meshnet, "bayesian_vnet": bayesian_vnet, - "synthgenerator": labels_to_image_model + "synthgenerator": labels_to_image_model, } diff --git a/nobrainer/models/labels_to_image_model.py b/nobrainer/models/labels_to_image_model.py index 0083b2d5..4031b297 100644 --- a/nobrainer/models/labels_to_image_model.py +++ b/nobrainer/models/labels_to_image_model.py @@ -13,45 +13,46 @@ License. """ +import keras.layers as KL +from keras.models import Model # python imports import numpy as np import tensorflow as tf -import keras.layers as KL -from keras.models import Model # third-party imports -from nobrainer.ext.lab2im import utils -from nobrainer.ext.lab2im import layers from nobrainer.ext.lab2im import edit_tensors as l2i_et +from nobrainer.ext.lab2im import layers, utils from nobrainer.ext.lab2im.edit_volumes import get_ras_axes -def labels_to_image_model(labels_shape, - n_channels, - generation_labels, - output_labels, - n_neutral_labels, - atlas_res, - target_res, - output_shape=None, - output_div_by_n=None, - flipping=True, - aff=None, - scaling_bounds=0.2, - rotation_bounds=15, - shearing_bounds=0.012, - translation_bounds=False, - nonlin_std=3., - nonlin_scale=.0625, - randomise_res=False, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.5, - bias_scale=.025, - return_gradients=False): +def labels_to_image_model( + labels_shape, + n_channels, + generation_labels, + output_labels, + n_neutral_labels, + atlas_res, + target_res, + output_shape=None, + output_div_by_n=None, + flipping=True, + aff=None, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=3.0, + nonlin_scale=0.0625, + randomise_res=False, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.5, + bias_scale=0.025, + return_gradients=False, +): """ This function builds a keras/tensorflow model to generate images from provided label maps. The images are generated by sampling a Gaussian Mixture Model (of given parameters), conditioned on the label map. @@ -147,28 +148,50 @@ def labels_to_image_model(labels_shape, labels_shape = utils.reformat_to_list(labels_shape) n_dims, _ = utils.get_dims(labels_shape) atlas_res = utils.reformat_to_n_channels_array(atlas_res, n_dims, n_channels) - data_res = atlas_res if data_res is None else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) - thickness = data_res if thickness is None else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + data_res = ( + atlas_res + if data_res is None + else utils.reformat_to_n_channels_array(data_res, n_dims, n_channels) + ) + thickness = ( + data_res + if thickness is None + else utils.reformat_to_n_channels_array(thickness, n_dims, n_channels) + ) atlas_res = atlas_res[0] - target_res = atlas_res if target_res is None else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + target_res = ( + atlas_res + if target_res is None + else utils.reformat_to_n_channels_array(target_res, n_dims)[0] + ) # get shapes - crop_shape, output_shape = get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_n) + crop_shape, output_shape = get_shapes( + labels_shape, output_shape, atlas_res, target_res, output_div_by_n + ) # define model inputs - labels_input = KL.Input(shape=labels_shape + [1], name='labels_input', dtype='int32') - means_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='means_input') - stds_input = KL.Input(shape=list(generation_labels.shape) + [n_channels], name='std_devs_input') + labels_input = KL.Input( + shape=labels_shape + [1], name="labels_input", dtype="int32" + ) + means_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="means_input" + ) + stds_input = KL.Input( + shape=list(generation_labels.shape) + [n_channels], name="std_devs_input" + ) list_inputs = [labels_input, means_input, stds_input] # deform labels - labels = layers.RandomSpatialDeformation(scaling_bounds=scaling_bounds, - rotation_bounds=rotation_bounds, - shearing_bounds=shearing_bounds, - translation_bounds=translation_bounds, - nonlin_std=nonlin_std, - nonlin_scale=nonlin_scale, - inter_method='nearest')(labels_input) + labels = layers.RandomSpatialDeformation( + scaling_bounds=scaling_bounds, + rotation_bounds=rotation_bounds, + shearing_bounds=shearing_bounds, + translation_bounds=translation_bounds, + nonlin_std=nonlin_std, + nonlin_scale=nonlin_scale, + inter_method="nearest", + )(labels_input) # cropping if crop_shape != labels_shape: @@ -176,58 +199,92 @@ def labels_to_image_model(labels_shape, # flipping if flipping: - assert aff is not None, 'aff should not be None if flipping is True' - labels = layers.RandomFlip(get_ras_axes(aff, n_dims)[0], True, generation_labels, n_neutral_labels)(labels) + assert aff is not None, "aff should not be None if flipping is True" + labels = layers.RandomFlip( + get_ras_axes(aff, n_dims)[0], True, generation_labels, n_neutral_labels + )(labels) # build synthetic image - image = layers.SampleConditionalGMM(generation_labels)([labels, means_input, stds_input]) + image = layers.SampleConditionalGMM(generation_labels)( + [labels, means_input, stds_input] + ) # apply bias field if bias_field_std > 0: image = layers.BiasFieldCorruption(bias_field_std, bias_scale, False)(image) # intensity augmentation - image = layers.IntensityAugmentation(clip=300, normalise=True, gamma_std=.5, separate_channels=True)(image) + image = layers.IntensityAugmentation( + clip=300, normalise=True, gamma_std=0.5, separate_channels=True + )(image) # loop over channels channels = list() - split = KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) if (n_channels > 1) else [image] + split = ( + KL.Lambda(lambda x: tf.split(x, [1] * n_channels, axis=-1))(image) + if (n_channels > 1) + else [image] + ) for i, channel in enumerate(split): if randomise_res: - max_res_iso = np.array(utils.reformat_to_list(max_res_iso, length=n_dims, dtype='float')) - max_res_aniso = np.array(utils.reformat_to_list(max_res_aniso, length=n_dims, dtype='float')) + max_res_iso = np.array( + utils.reformat_to_list(max_res_iso, length=n_dims, dtype="float") + ) + max_res_aniso = np.array( + utils.reformat_to_list(max_res_aniso, length=n_dims, dtype="float") + ) max_res = np.maximum(max_res_iso, max_res_aniso) - resolution, blur_res = layers.SampleResolution(atlas_res, max_res_iso, max_res_aniso)(means_input) - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, resolution, thickness=blur_res) - channel = layers.DynamicGaussianBlur(0.75 * max_res / np.array(atlas_res), 1.03)([channel, sigma]) - channel = layers.MimicAcquisition(atlas_res, atlas_res, output_shape, False)([channel, resolution]) + resolution, blur_res = layers.SampleResolution( + atlas_res, max_res_iso, max_res_aniso + )(means_input) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, resolution, thickness=blur_res + ) + channel = layers.DynamicGaussianBlur( + 0.75 * max_res / np.array(atlas_res), 1.03 + )([channel, sigma]) + channel = layers.MimicAcquisition( + atlas_res, atlas_res, output_shape, False + )([channel, resolution]) channels.append(channel) else: - sigma = l2i_et.blurring_sigma_for_downsampling(atlas_res, data_res[i], thickness=thickness[i]) + sigma = l2i_et.blurring_sigma_for_downsampling( + atlas_res, data_res[i], thickness=thickness[i] + ) channel = layers.GaussianBlur(sigma, 1.03)(channel) - resolution = KL.Lambda(lambda x: tf.convert_to_tensor(data_res[i], dtype='float32'))([]) - channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)([channel, resolution]) + resolution = KL.Lambda( + lambda x: tf.convert_to_tensor(data_res[i], dtype="float32") + )([]) + channel = layers.MimicAcquisition(atlas_res, data_res[i], output_shape)( + [channel, resolution] + ) channels.append(channel) # concatenate all channels back - image = KL.Lambda(lambda x: tf.concat(x, -1))(channels) if len(channels) > 1 else channels[0] + image = ( + KL.Lambda(lambda x: tf.concat(x, -1))(channels) + if len(channels) > 1 + else channels[0] + ) # compute image gradient if return_gradients: - image = layers.ImageGradients('sobel', True, name='image_gradients')(image) + image = layers.ImageGradients("sobel", True, name="image_gradients")(image) image = layers.IntensityAugmentation(clip=10, normalise=True)(image) # resample labels at target resolution if crop_shape != output_shape: - labels = l2i_et.resample_tensor(labels, output_shape, interp_method='nearest') + labels = l2i_et.resample_tensor(labels, output_shape, interp_method="nearest") # map generation labels to segmentation values - labels = layers.ConvertLabels(generation_labels, dest_values=output_labels, name='labels_out')(labels) + labels = layers.ConvertLabels( + generation_labels, dest_values=output_labels, name="labels_out" + )(labels) # build model (dummy layer enables to keep the labels when plugging this model to other models) - image = KL.Lambda(lambda x: x[0], name='image_out')([image, labels]) + image = KL.Lambda(lambda x: x[0], name="image_out")([image, labels]) brain_model = Model(inputs=list_inputs, outputs=[image, labels]) return brain_model @@ -248,25 +305,39 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ # output shape specified, need to get cropping shape, and resample shape if necessary if output_shape is not None: - output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype='int') + output_shape = utils.reformat_to_list(output_shape, length=n_dims, dtype="int") # make sure that output shape is smaller or equal to label shape if resample_factor is not None: - output_shape = [min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(int(labels_shape[i] * resample_factor[i]), output_shape[i]) + for i in range(n_dims) + ] else: - output_shape = [min(labels_shape[i], output_shape[i]) for i in range(n_dims)] + output_shape = [ + min(labels_shape[i], output_shape[i]) for i in range(n_dims) + ] # make sure output shape is divisible by output_div_by_n if output_div_by_n is not None: - tmp_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in output_shape] + tmp_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] if output_shape != tmp_shape: - print('output shape {0} not divisible by {1}, changed to {2}'.format(output_shape, output_div_by_n, - tmp_shape)) + print( + "output shape {0} not divisible by {1}, changed to {2}".format( + output_shape, output_div_by_n, tmp_shape + ) + ) output_shape = tmp_shape # get cropping and resample shape if resample_factor is not None: - cropping_shape = [int(np.around(output_shape[i]/resample_factor[i], 0)) for i in range(n_dims)] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] else: cropping_shape = output_shape @@ -278,19 +349,32 @@ def get_shapes(labels_shape, output_shape, atlas_res, target_res, output_div_by_ # if resampling, get the potential output_shape and check if it is divisible by n if resample_factor is not None: - output_shape = [int(labels_shape[i] * resample_factor[i]) for i in range(n_dims)] - output_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in output_shape] - cropping_shape = [int(np.around(output_shape[i] / resample_factor[i], 0)) for i in range(n_dims)] + output_shape = [ + int(labels_shape[i] * resample_factor[i]) for i in range(n_dims) + ] + output_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in output_shape + ] + cropping_shape = [ + int(np.around(output_shape[i] / resample_factor[i], 0)) + for i in range(n_dims) + ] # if no resampling, simply check if image_shape is divisible by n else: - cropping_shape = [utils.find_closest_number_divisible_by_m(s, output_div_by_n) for s in labels_shape] + cropping_shape = [ + utils.find_closest_number_divisible_by_m(s, output_div_by_n) + for s in labels_shape + ] output_shape = cropping_shape # if no need to be divisible by n, simply take cropping_shape as image_shape, and build output_shape else: cropping_shape = labels_shape if resample_factor is not None: - output_shape = [int(cropping_shape[i] * resample_factor[i]) for i in range(n_dims)] + output_shape = [ + int(cropping_shape[i] * resample_factor[i]) for i in range(n_dims) + ] else: output_shape = cropping_shape diff --git a/nobrainer/processing/brain_generator.py b/nobrainer/processing/brain_generator.py index 2b7bfee9..aea63b78 100644 --- a/nobrainer/processing/brain_generator.py +++ b/nobrainer/processing/brain_generator.py @@ -13,52 +13,53 @@ License. """ - # python imports import numpy as np # project imports from nobrainer.ext.SynthSeg.model_inputs import build_model_inputs -from nobrainer.models.labels_to_image_model import labels_to_image_model # third-party imports -from nobrainer.ext.lab2im import utils, edit_volumes +from nobrainer.ext.lab2im import edit_volumes, utils +from nobrainer.models.labels_to_image_model import labels_to_image_model class BrainGenerator: - def __init__(self, - labels_dir, - generation_labels=None, - n_neutral_labels=None, - output_labels=None, - subjects_prob=None, - batchsize=1, - n_channels=1, - target_res=None, - output_shape=None, - output_div_by_n=None, - prior_distributions='uniform', - generation_classes=None, - prior_means=None, - prior_stds=None, - use_specific_stats_for_channel=False, - mix_prior_and_random=False, - flipping=True, - scaling_bounds=.2, - rotation_bounds=15, - shearing_bounds=.012, - translation_bounds=False, - nonlin_std=4., - nonlin_scale=.04, - randomise_res=True, - max_res_iso=4., - max_res_aniso=8., - data_res=None, - thickness=None, - bias_field_std=.7, - bias_scale=.025, - return_gradients=False): + def __init__( + self, + labels_dir, + generation_labels=None, + n_neutral_labels=None, + output_labels=None, + subjects_prob=None, + batchsize=1, + n_channels=1, + target_res=None, + output_shape=None, + output_div_by_n=None, + prior_distributions="uniform", + generation_classes=None, + prior_means=None, + prior_stds=None, + use_specific_stats_for_channel=False, + mix_prior_and_random=False, + flipping=True, + scaling_bounds=0.2, + rotation_bounds=15, + shearing_bounds=0.012, + translation_bounds=False, + nonlin_std=4.0, + nonlin_scale=0.04, + randomise_res=True, + max_res_iso=4.0, + max_res_aniso=8.0, + data_res=None, + thickness=None, + bias_field_std=0.7, + bias_scale=0.025, + return_gradients=False, + ): """ This class is wrapper around the labels_to_image_model model. It contains the GPU model that generates images from labels maps, and a python generator that supplies the input data for this model. @@ -195,16 +196,21 @@ def __init__(self, # prepare data files self.labels_paths = utils.list_images_in_folder(labels_dir) if subjects_prob is not None: - self.subjects_prob = np.array(utils.reformat_to_list(subjects_prob, load_as_numpy=True), dtype='float32') - assert len(self.subjects_prob) == len(self.labels_paths), \ - 'subjects_prob should have the same length as labels_path, ' \ - 'had {} and {}'.format(len(self.subjects_prob), len(self.labels_paths)) + self.subjects_prob = np.array( + utils.reformat_to_list(subjects_prob, load_as_numpy=True), + dtype="float32", + ) + assert len(self.subjects_prob) == len(self.labels_paths), ( + "subjects_prob should have the same length as labels_path, " + "had {} and {}".format(len(self.subjects_prob), len(self.labels_paths)) + ) else: self.subjects_prob = None # generation parameters - self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = \ + self.labels_shape, self.aff, self.n_dims, _, self.header, self.atlas_res = ( utils.get_volume_info(self.labels_paths[0], aff_ref=np.eye(4)) + ) self.n_channels = n_channels if generation_labels is not None: self.generation_labels = utils.load_array_if_path(generation_labels) @@ -228,11 +234,13 @@ def __init__(self, self.prior_distributions = prior_distributions if generation_classes is not None: self.generation_classes = utils.load_array_if_path(generation_classes) - assert self.generation_classes.shape == self.generation_labels.shape, \ - 'if provided, generation_classes should have the same shape as generation_labels' + assert ( + self.generation_classes.shape == self.generation_labels.shape + ), "if provided, generation_classes should have the same shape as generation_labels" unique_classes = np.unique(self.generation_classes) - assert np.array_equal(unique_classes, np.arange(np.max(unique_classes)+1)), \ - 'generation_classes should a linear range between 0 and its maximum value.' + assert np.array_equal( + unique_classes, np.arange(np.max(unique_classes) + 1) + ), "generation_classes should a linear range between 0 and its maximum value." else: self.generation_classes = np.arange(self.generation_labels.shape[0]) self.prior_means = utils.load_array_if_path(prior_means) @@ -251,8 +259,9 @@ def __init__(self, self.max_res_iso = max_res_iso self.max_res_aniso = max_res_aniso self.data_res = utils.load_array_if_path(data_res) - assert not (self.randomise_res & (self.data_res is not None)), \ - 'randomise_res and data_res cannot be provided at the same time' + assert not ( + self.randomise_res & (self.data_res is not None) + ), "randomise_res and data_res cannot be provided at the same time" self.thickness = utils.load_array_if_path(thickness) # bias field parameters self.bias_field_std = bias_field_std @@ -260,57 +269,65 @@ def __init__(self, self.return_gradients = return_gradients # build transformation model - self.labels_to_image_model, self.model_output_shape = self._build_labels_to_image_model() + self.labels_to_image_model, self.model_output_shape = ( + self._build_labels_to_image_model() + ) # build generator for model inputs - self.model_inputs_generator = self._build_model_inputs_generator(mix_prior_and_random) + self.model_inputs_generator = self._build_model_inputs_generator( + mix_prior_and_random + ) # build brain generator self.brain_generator = self._build_brain_generator() def _build_labels_to_image_model(self): # build_model - lab_to_im_model = labels_to_image_model(labels_shape=self.labels_shape, - n_channels=self.n_channels, - generation_labels=self.generation_labels, - output_labels=self.output_labels, - n_neutral_labels=self.n_neutral_labels, - atlas_res=self.atlas_res, - target_res=self.target_res, - output_shape=self.output_shape, - output_div_by_n=self.output_div_by_n, - flipping=self.flipping, - aff=np.eye(4), - scaling_bounds=self.scaling_bounds, - rotation_bounds=self.rotation_bounds, - shearing_bounds=self.shearing_bounds, - translation_bounds=self.translation_bounds, - nonlin_std=self.nonlin_std, - nonlin_scale=self.nonlin_scale, - randomise_res=self.randomise_res, - max_res_iso=self.max_res_iso, - max_res_aniso=self.max_res_aniso, - data_res=self.data_res, - thickness=self.thickness, - bias_field_std=self.bias_field_std, - bias_scale=self.bias_scale, - return_gradients=self.return_gradients) + lab_to_im_model = labels_to_image_model( + labels_shape=self.labels_shape, + n_channels=self.n_channels, + generation_labels=self.generation_labels, + output_labels=self.output_labels, + n_neutral_labels=self.n_neutral_labels, + atlas_res=self.atlas_res, + target_res=self.target_res, + output_shape=self.output_shape, + output_div_by_n=self.output_div_by_n, + flipping=self.flipping, + aff=np.eye(4), + scaling_bounds=self.scaling_bounds, + rotation_bounds=self.rotation_bounds, + shearing_bounds=self.shearing_bounds, + translation_bounds=self.translation_bounds, + nonlin_std=self.nonlin_std, + nonlin_scale=self.nonlin_scale, + randomise_res=self.randomise_res, + max_res_iso=self.max_res_iso, + max_res_aniso=self.max_res_aniso, + data_res=self.data_res, + thickness=self.thickness, + bias_field_std=self.bias_field_std, + bias_scale=self.bias_scale, + return_gradients=self.return_gradients, + ) out_shape = lab_to_im_model.output[0].get_shape().as_list()[1:] return lab_to_im_model, out_shape def _build_model_inputs_generator(self, mix_prior_and_random): # build model's inputs generator - model_inputs_generator = build_model_inputs(path_label_maps=self.labels_paths, - n_labels=len(self.generation_labels), - batchsize=self.batchsize, - n_channels=self.n_channels, - subjects_prob=self.subjects_prob, - generation_classes=self.generation_classes, - prior_means=self.prior_means, - prior_stds=self.prior_stds, - prior_distributions=self.prior_distributions, - use_specific_stats_for_channel=self.use_specific_stats_for_channel, - mix_prior_and_random=mix_prior_and_random) + model_inputs_generator = build_model_inputs( + path_label_maps=self.labels_paths, + n_labels=len(self.generation_labels), + batchsize=self.batchsize, + n_channels=self.n_channels, + subjects_prob=self.subjects_prob, + generation_classes=self.generation_classes, + prior_means=self.prior_means, + prior_stds=self.prior_stds, + prior_distributions=self.prior_distributions, + use_specific_stats_for_channel=self.use_specific_stats_for_channel, + mix_prior_and_random=mix_prior_and_random, + ) return model_inputs_generator def _build_brain_generator(self): @@ -326,10 +343,16 @@ def generate_brain(self): list_images = list() list_labels = list() for i in range(self.batchsize): - list_images.append(edit_volumes.align_volume_to_ref(image[i], np.eye(4), - aff_ref=self.aff, n_dims=self.n_dims)) - list_labels.append(edit_volumes.align_volume_to_ref(labels[i], np.eye(4), - aff_ref=self.aff, n_dims=self.n_dims)) + list_images.append( + edit_volumes.align_volume_to_ref( + image[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) + list_labels.append( + edit_volumes.align_volume_to_ref( + labels[i], np.eye(4), aff_ref=self.aff, n_dims=self.n_dims + ) + ) image = np.squeeze(np.stack(list_images, axis=0)) labels = np.squeeze(np.stack(list_labels, axis=0)) return image, labels diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index 58d0e03e..11aadab7 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -50,7 +50,7 @@ def fit( metrics=metrics.dice, callbacks=None, verbose=1, - initial_epoch=0 + initial_epoch=0, ): """Train a segmentation model""" # TODO: check validity of datasets @@ -116,7 +116,7 @@ def _compile(): ), callbacks=callbacks, verbose=verbose, - initial_epoch=initial_epoch + initial_epoch=initial_epoch, ) return self diff --git a/nobrainer/tfrecord.py b/nobrainer/tfrecord.py index 0f3bbd95..a4701fde 100644 --- a/nobrainer/tfrecord.py +++ b/nobrainer/tfrecord.py @@ -58,7 +58,9 @@ def write( verbose: int, if 1, print progress bar. If 0, print nothing. """ n_examples = len(features_labels) - shards = np.array_split(features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard)) + shards = np.array_split( + features_labels, np.arange(examples_per_shard, n_examples, examples_per_shard) + ) # Test that the `filename_template` has a `shard` formatting key. try: @@ -77,7 +79,9 @@ def write( # This is the object that returns a protocol buffer string of the feature and label # on each iteration. It is pickle-able, unlike a generator. proto_iterators = [ - _ProtoIterator(s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions) + _ProtoIterator( + s, to_ras=to_ras, multi_resolution=multi_resolution, resolutions=resolutions + ) for s in shards ] # Set up positional arguments for the core writer function.