diff --git a/CHANGELOG.md b/CHANGELOG.md index d9e4a33e55e..38d71e5ba5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805)) --- diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index ca4178d35ee..d4930d905ab 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -182,14 +182,17 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] """Update state with predictions and targets.""" _input_validator(preds, target, ignore_score=True) - for p, t in zip(preds, target): - det_boxes = self._get_safe_item_values(p["boxes"]) - gt_boxes = self._get_safe_item_values(t["boxes"]) - self.groundtruth_labels.append(t["labels"]) + for p_i, t_i in zip(preds, target): + det_boxes = self._get_safe_item_values(p_i["boxes"]) + gt_boxes = self._get_safe_item_values(t_i["boxes"]) + self.groundtruth_labels.append(t_i["labels"]) iou_matrix = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) # N x M if self.respect_labels: - label_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) # N x M + if det_boxes.numel() > 0 and gt_boxes.numel() > 0: + label_eq = p_i["labels"].unsqueeze(1) == t_i["labels"].unsqueeze(0) # N x M + else: + label_eq = torch.eye(iou_matrix.shape[0], dtype=bool, device=iou_matrix.device) # type: ignore[call-overload] iou_matrix[~label_eq] = self._invalid_val self.iou_matrix.append(iou_matrix) diff --git a/src/torchmetrics/functional/detection/ciou.py b/src/torchmetrics/functional/detection/ciou.py index 650651b2e4f..3df249f020d 100644 --- a/src/torchmetrics/functional/detection/ciou.py +++ b/src/torchmetrics/functional/detection/ciou.py @@ -31,6 +31,11 @@ def _ciou_update( from torchvision.ops import complete_box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = complete_box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/detection/diou.py b/src/torchmetrics/functional/detection/diou.py index 7a9a3d907a9..3d71843ac62 100644 --- a/src/torchmetrics/functional/detection/diou.py +++ b/src/torchmetrics/functional/detection/diou.py @@ -31,6 +31,11 @@ def _diou_update( from torchvision.ops import distance_box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = distance_box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/detection/giou.py b/src/torchmetrics/functional/detection/giou.py index feae12d3011..c3e467e45ff 100644 --- a/src/torchmetrics/functional/detection/giou.py +++ b/src/torchmetrics/functional/detection/giou.py @@ -31,6 +31,11 @@ def _giou_update( from torchvision.ops import generalized_box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = generalized_box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py index 249b30dd2d9..62873c86f58 100644 --- a/src/torchmetrics/functional/detection/iou.py +++ b/src/torchmetrics/functional/detection/iou.py @@ -32,6 +32,11 @@ def _iou_update( from torchvision.ops import box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index e76ce966474..88a6408c536 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -353,6 +353,43 @@ def test_corner_case_only_one_empty_prediction(self, class_metric, functional_me for val in res.values(): assert val == torch.tensor(0.0) + def test_empty_preds_and_target(self, class_metric, functional_metric, reference_metric): + """Check that for either empty preds and targets that the metric returns 0 in these cases before averaging.""" + x = [ + { + "boxes": torch.empty(size=(0, 4), dtype=torch.float32), + "labels": torch.tensor([], dtype=torch.long), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + }, + ] + + y = [ + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + ] + metric = class_metric() + metric.update(x, y) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.5) + + metric = class_metric() + metric.update(y, x) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.5) + def test_corner_case(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""