Skip to content

Commit

Permalink
Implementation of intensity clipping transform: bot hard and soft app…
Browse files Browse the repository at this point in the history
…roaches (Project-MONAI#7535)

Fixes Issue Project-MONAI#7512.

### Description

Addition of a transformation allowing values above or below a certain
percentile to be clipped.
Clipping can be hard or soft.
With soft clipping, the function remains derivable and the order of the
values is respected, with smoother corners.

The soft clipping function is based on this medium article
https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291

It's important to note that I've chosen to switch from Nones values to
percentiles to take account of the fact that soft clipping can be
one-sided or two-sided.
In fact, providing percentiles of 100 or 0 doesn't change anything in
the case of hard clipping, but it does in the case of soft clipping
because the function is smoothed. Hence the interest in introducing the
possibility of putting None to avoid smoothing the function on one side
or the other.

To implement this we had to define a `softplus` function in
`monai.transforms.utils_pytorch_numpy_unification.py`. One of the
problems is that `np.logaddexp` do not exactly yields same outputs as
`torch.logaddexp`. I've left it as is and lowered the tolerance of the
tests slightly, but it's possible to force the conversion to numpy and
then switch back to torch to ensure better unification between the
frameworks.

I've also added the `soft_clip` function in `monai.transforms.utils.py`
with the associated unit tests to ensure that the transformation works
properly.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 5, 2024
1 parent 625967c commit c0b9cc0
Show file tree
Hide file tree
Showing 9 changed files with 763 additions and 1 deletion.
12 changes: 12 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ Intensity
:members:
:special-members: __call__

`ClipIntensityPercentiles`
""""""""""""""""""""""""""
.. autoclass:: ClipIntensityPercentiles
:members:
:special-members: __call__

`RandScaleIntensity`
""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensity.png
Expand Down Expand Up @@ -1405,6 +1411,12 @@ Intensity (Dict)
:members:
:special-members: __call__

`ClipIntensityPercentilesd`
"""""""""""""""""""""""""""
.. autoclass:: ClipIntensityPercentilesd
:members:
:special-members: __call__

`RandScaleIntensityd`
"""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensityd.png
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from .croppad.functional import crop_func, crop_or_pad_nd, pad_func, pad_nd
from .intensity.array import (
AdjustContrast,
ClipIntensityPercentiles,
ComputeHoVerMaps,
DetectEnvelope,
ForegroundMask,
Expand Down Expand Up @@ -135,6 +136,9 @@
AdjustContrastd,
AdjustContrastD,
AdjustContrastDict,
ClipIntensityPercentilesd,
ClipIntensityPercentilesD,
ClipIntensityPercentilesDict,
ComputeHoVerMapsd,
ComputeHoVerMapsD,
ComputeHoVerMapsDict,
Expand Down
148 changes: 147 additions & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils.enums import TransformBackends
from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
Expand All @@ -54,6 +54,7 @@
"NormalizeIntensity",
"ThresholdIntensity",
"ScaleIntensityRange",
"ClipIntensityPercentiles",
"AdjustContrast",
"RandAdjustContrast",
"ScaleIntensityRangePercentiles",
Expand Down Expand Up @@ -1007,6 +1008,151 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return ret


class ClipIntensityPercentiles(Transform):
"""
Apply clip based on the intensity distribution of input image.
If `sharpness_factor` is provided, the intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291
Soft clipping preserves the order of the values and maintains the gradient everywhere.
For example:
.. code-block:: python
:emphasize-lines: 11, 22
image = torch.Tensor(
[[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]]])
# Hard clipping from lower and upper image intensity percentiles
hard_clipper = ClipIntensityPercentiles(30, 70)
print(hard_clipper(image))
metatensor([[[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.],
[2., 2., 3., 4., 4.]]])
# Soft clipping from lower and upper image intensity percentiles
soft_clipper = ClipIntensityPercentiles(30, 70, 10.)
print(soft_clipper(image))
metatensor([[[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
[2.0000, 2.0693, 3.0000, 3.9307, 4.0000]]])
See Also:
- :py:class:`monai.transforms.ScaleIntensityRangePercentiles`
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
lower: float | None,
upper: float | None,
sharpness_factor: float | None = None,
channel_wise: bool = False,
return_clipping_values: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
lower: lower intensity percentile. In the case of hard clipping, None will have the same effect as 0 by
not clipping the lowest input values. However, in the case of soft clipping, None and zero will have
two different effects: None will not apply clipping to low values, whereas zero will still transform
the lower values according to the soft clipping transformation. Please check for more details:
https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291.
upper: upper intensity percentile. The same as for lower, but this time with the highest values. If we
are looking to perform soft clipping, if None then there will be no effect on this side whereas if set
to 100, the values will be passed via the corresponding clipping equation.
sharpness_factor: if not None, the intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)).
defaults to None.
channel_wise: if True, compute intensity percentile and normalize every channel separately.
default to False.
return_clipping_values: whether to return the calculated percentiles in tensor meta information.
If soft clipping and requested percentile is None, return None as the corresponding clipping
values in meta information. Clipping values are stored in a list with each element corresponding
to a channel if channel_wise is set to True. defaults to False.
dtype: output data type, if None, same as input image. defaults to float32.
"""
if lower is None and upper is None:
raise ValueError("lower or upper percentiles must be provided")
if lower is not None and (lower < 0.0 or lower > 100.0):
raise ValueError("Percentiles must be in the range [0, 100]")
if upper is not None and (upper < 0.0 or upper > 100.0):
raise ValueError("Percentiles must be in the range [0, 100]")
if upper is not None and lower is not None and upper < lower:
raise ValueError("upper must be greater than or equal to lower")
if sharpness_factor is not None and sharpness_factor <= 0:
raise ValueError("sharpness_factor must be greater than 0")

self.lower = lower
self.upper = upper
self.sharpness_factor = sharpness_factor
self.channel_wise = channel_wise
if return_clipping_values:
self.clipping_values: list[tuple[float | None, float | None]] = []
self.return_clipping_values = return_clipping_values
self.dtype = dtype

def _clip(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if self.sharpness_factor is not None:
lower_percentile = percentile(img, self.lower) if self.lower is not None else None
upper_percentile = percentile(img, self.upper) if self.upper is not None else None
img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
else:
lower_percentile = percentile(img, self.lower) if self.lower is not None else percentile(img, 0)
upper_percentile = percentile(img, self.upper) if self.upper is not None else percentile(img, 100)
img = clip(img, lower_percentile, upper_percentile)

if self.return_clipping_values:
self.clipping_values.append(
(
(
lower_percentile
if lower_percentile is None
else lower_percentile.item() if hasattr(lower_percentile, "item") else lower_percentile
),
(
upper_percentile
if upper_percentile is None
else upper_percentile.item() if hasattr(upper_percentile, "item") else upper_percentile
),
)
)
img = convert_to_tensor(img, track_meta=False)
return img

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
if self.channel_wise:
img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore
else:
img_t = self._clip(img=img_t)

img = convert_to_dst_type(img_t, dst=img)[0]
if self.return_clipping_values:
img.meta["clipping_values"] = self.clipping_values # type: ignore

return img


class AdjustContrast(Transform):
"""
Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::
Expand Down
35 changes: 35 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.data.meta_obj import get_track_meta
from monai.transforms.intensity.array import (
AdjustContrast,
ClipIntensityPercentiles,
ComputeHoVerMaps,
ForegroundMask,
GaussianSharpen,
Expand Down Expand Up @@ -77,6 +78,7 @@
"NormalizeIntensityd",
"ThresholdIntensityd",
"ScaleIntensityRanged",
"ClipIntensityPercentilesd",
"AdjustContrastd",
"RandAdjustContrastd",
"ScaleIntensityRangePercentilesd",
Expand Down Expand Up @@ -122,6 +124,8 @@
"ThresholdIntensityDict",
"ScaleIntensityRangeD",
"ScaleIntensityRangeDict",
"ClipIntensityPercentilesD",
"ClipIntensityPercentilesDict",
"AdjustContrastD",
"AdjustContrastDict",
"RandAdjustContrastD",
Expand Down Expand Up @@ -886,6 +890,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ClipIntensityPercentilesd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ClipIntensityPercentiles`.
Clip the intensity values of input image to a specific range based on the intensity distribution of the input.
If `sharpness_factor` is provided, the intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor) * softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
"""

def __init__(
self,
keys: KeysCollection,
lower: float | None,
upper: float | None,
sharpness_factor: float | None = None,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.scaler = ClipIntensityPercentiles(
lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype
)

def __call__(self, data: dict) -> dict:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
return d


class AdjustContrastd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AdjustContrast`.
Expand Down Expand Up @@ -1929,6 +1963,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
NormalizeIntensityD = NormalizeIntensityDict = NormalizeIntensityd
ThresholdIntensityD = ThresholdIntensityDict = ThresholdIntensityd
ScaleIntensityRangeD = ScaleIntensityRangeDict = ScaleIntensityRanged
ClipIntensityPercentilesD = ClipIntensityPercentilesDict = ClipIntensityPercentilesd
AdjustContrastD = AdjustContrastDict = AdjustContrastd
RandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd
ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd
Expand Down
37 changes: 37 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
nonzero,
ravel,
searchsorted,
softplus,
unique,
unravel_index,
where,
Expand Down Expand Up @@ -131,9 +132,45 @@
"resolves_modes",
"has_status_keys",
"distance_transform_edt",
"soft_clip",
]


def soft_clip(
arr: NdarrayOrTensor,
sharpness_factor: float = 1.0,
minv: NdarrayOrTensor | float | int | None = None,
maxv: NdarrayOrTensor | float | int | None = None,
dtype: DtypeLike | torch.dtype = np.float32,
) -> NdarrayOrTensor:
"""
Apply soft clip to the input array or tensor.
The intensity values will be soft clipped according to
f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291
To perform one-sided clipping, set either minv or maxv to None.
Args:
arr: input array to clip.
sharpness_factor: the sharpness of the soft clip function, default to 1.
minv: minimum value of target clipped array.
maxv: maximum value of target clipped array.
dtype: if not None, convert input array to dtype before computation.
"""

if dtype is not None:
arr, *_ = convert_data_type(arr, dtype=dtype)

v = arr
if minv is not None:
v = v + softplus(-sharpness_factor * (arr - minv)) / sharpness_factor
if maxv is not None:
v = v - softplus(sharpness_factor * (arr - maxv)) / sharpness_factor

return v


def rand_choice(prob: float = 0.5) -> bool:
"""
Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance.
Expand Down
15 changes: 15 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,24 @@
"median",
"mean",
"std",
"softplus",
]


def softplus(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""stable softplus through `np.logaddexp` with equivalent implementation for torch.
Args:
x: array/tensor.
Returns:
Softplus of the input.
"""
if isinstance(x, np.ndarray):
return np.logaddexp(np.zeros_like(x), x)
return torch.logaddexp(torch.zeros_like(x), x)


def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool:
"""`np.allclose` with equivalent implementation for torch."""
b, *_ = convert_to_dst_type(b, a, wrap_sequence=True)
Expand Down
Loading

0 comments on commit c0b9cc0

Please sign in to comment.