diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 3f02fae6b8..4108820bec 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -23,6 +23,7 @@
 
 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
 
@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
     The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
     the inter-over-union calculation to smooth results respectively, these values should be small.
 
-    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
-    Medical Image Segmentation, 3DV, 2016.
+    The original papers:
+
+    Milletari, F. et al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
+    Medical Image Segmentation. 3DV 2016.
+
+    Wang, Z. et al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
+    Soft Labels. NeurIPS 2023.
+
+    Wang, Z. et al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
 
     """
 
@@ -58,6 +67,7 @@ def __init__(
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -89,6 +99,8 @@ def __init__(
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0). The value/values
                 should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True, a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
 
@@ -114,6 +126,7 @@
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        intersection = torch.sum(target * input, dim=reduce_axis)
-
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr
 
-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f: torch.Tensor = 1 - numerator / denominator
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -272,6 +279,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -295,6 +303,8 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True, a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
 
@@ -319,6 +329,7 @@
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
-        intersection = torch.sum(target * input, reduce_axis)
-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-
-        denominator = ground_o + pred_o
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)
 
+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -388,7 +399,7 @@
             w = w + infs * max_values
 
         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
 
diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py
index 4f22bf84b4..154f34c526 100644
--- a/monai/losses/tversky.py
+++ b/monai/losses/tversky.py
@@ -17,6 +17,7 @@
 import torch
 from torch.nn.modules.loss import _Loss
 
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import LossReduction
 
@@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
     Sadegh et al. (2017) Tversky loss function for image segmentation using 3D fully convolutional deep networks.
     (https://arxiv.org/abs/1706.05721)
 
+    Wang, Z. et al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
+
     Adapted from:
         https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631
 
@@ -46,6 +50,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -70,6 +75,8 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True, a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
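
[Editorial note, not part of the patch] The `soft_label` flag documented above changes how the true-positive term is computed. With hard 0/1 targets the usual intersection `sum(input * target)` equals the true-positive mass, but with soft targets it under-counts agreement: even a prediction identical to the target scores below a perfect overlap. The patch instead relies on the identity `(|x| + |y| - |x - y|) / 2 == min(x, y)`, which holds elementwise for non-negative values. A minimal sketch in plain PyTorch (the names `x`, `y` are illustrative only):

    import torch

    x = torch.tensor([0.5, 0.9, 0.1])  # soft prediction
    y = torch.tensor([0.5, 0.9, 0.1])  # identical soft target

    naive_tp = (x * y).sum()  # 1.07: the product under-counts a perfect soft match
    soft_tp = (x.abs().sum() + y.abs().sum() - (x - y).abs().sum()) / 2  # 1.50
    assert torch.isclose(soft_tp, torch.minimum(x, y).sum())
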
@@ -93,6 +100,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -134,20 +142,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
 
-        p0 = input
-        p1 = 1 - p0
-        g0 = target
-        g1 = 1 - g0
-
         # reducing only spatial dimensions (not batch nor channels)
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        tp = torch.sum(p0 * g0, reduce_axis)
-        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
-        fn = self.beta * torch.sum(p1 * g0, reduce_axis)
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
+        fp *= self.alpha
+        fn *= self.beta
 
         numerator = tp + self.smooth_nr
         denominator = tp + fp + fn + self.smooth_dr
diff --git a/monai/losses/utils.py b/monai/losses/utils.py
new file mode 100644
index 0000000000..782fd9c9c2
--- /dev/null
+++ b/monai/losses/utils.py
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.linalg as LA
+
+
+def compute_tp_fp_fn(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    reduce_axis: list[int],
+    ord: int,
+    soft_label: bool,
+    decoupled: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        input: the shape should be BNH[WD], where N is the number of classes.
+        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+        reduce_axis: the axis to be reduced.
+        ord: the order of the vector norm.
+        soft_label: whether the target contains non-binary values (soft labels) or not.
+            If True, a soft label formulation of the loss will be used.
+        decoupled: whether the input and the target should be decoupled when computing fp and fn.
+            Only used by the original implementation, i.e. when ``soft_label`` is False.
+
+    Adapted from:
+        https://github.com/zifuwanggg/JDTLosses
+    """
+
+    # the original implementation, which is erroneous when soft labels are used
+    if ord == 1 and not soft_label:
+        tp = torch.sum(input * target, dim=reduce_axis)
+        # the original implementation of the Dice and Jaccard losses
+        if decoupled:
+            fp = torch.sum(input, dim=reduce_axis) - tp
+            fn = torch.sum(target, dim=reduce_axis) - tp
+        # the original implementation of the Tversky loss
+        else:
+            fp = torch.sum(input * (1 - target), dim=reduce_axis)
+            fn = torch.sum((1 - input) * target, dim=reduce_axis)
+    # the new implementation, which is correct with soft labels
+    # and identical to the original implementation with hard labels
+    else:
+        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
+        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
+        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)
+
+        if ord > 1:
+            pred_o = torch.pow(pred_o, exponent=ord)
+            ground_o = torch.pow(ground_o, exponent=ord)
+            difference = torch.pow(difference, exponent=ord)
+
+        tp = (pred_o + ground_o - difference) / 2
+        fp = pred_o - tp
+        fn = ground_o - tp
+
+    return tp, fp, fn
diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py
index 14aa6ec241..cea6ccf113 100644
--- a/tests/test_dice_loss.py
+++ b/tests/test_dice_loss.py
@@ -34,6 +34,22 @@
         },
         0.416657,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307773,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
         {
diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py
index 5738f4a089..9706c2e746 100644
--- a/tests/test_generalized_dice_loss.py
+++ b/tests/test_generalized_dice_loss.py
@@ -34,6 +34,22 @@
         },
         0.416597,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307748,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0},
         {
diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py
index 0365503ea2..73a841a55d 100644
--- a/tests/test_tversky_loss.py
+++ b/tests/test_tversky_loss.py
@@ -34,6 +34,22 @@
         },
         0.416657,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307773,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
         {
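
[Editorial note, not part of the patch] A usage sketch of the new flag, assuming this patch is applied on top of MONAI. The tensors and expected values mirror the DiceLoss test cases added above: with `soft_label=False` a prediction identical to a soft target is still penalized, with `soft_label=True` the loss goes to zero, and with hard (one-hot) targets the two settings coincide, per the comments in `compute_tp_fp_fn`:

    import torch
    from monai.losses import DiceLoss

    pred = torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]])
    target = pred.clone()  # a perfectly matching soft target

    hard = DiceLoss(include_background=True, smooth_nr=1e-4, smooth_dr=1e-4, soft_label=False)
    soft = DiceLoss(include_background=True, smooth_nr=1e-4, smooth_dr=1e-4, soft_label=True)

    print(hard(pred, target).item())  # ~0.307773: sum(input * target) under-counts soft overlap
    print(soft(pred, target).item())  # ~0.0: the soft-label formulation is minimized at the target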