diff --git a/docs/usage.md b/docs/usage.md
index 518bda7..5aa41af 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -365,7 +365,7 @@ from od_metrics import iou
y_true = [[25, 16, 38, 56], [129, 123, 41, 62]]
y_pred = [[25, 27, 37, 54], [119, 111, 40, 67], [124, 9, 49, 67]]
-result = iou(y_true, y_pred)
+result = iou(y_true, y_pred, box_format="xywh")
print(result)
"""
array([[0.67655425, 0. ],
diff --git a/samples/samples.ipynb b/samples/samples.ipynb
index 16b5d9f..db76155 100644
--- a/samples/samples.ipynb
+++ b/samples/samples.ipynb
@@ -340,7 +340,7 @@
"y_true = [[25, 16, 38, 56], [129, 123, 41, 62]]\n",
"y_pred = [[25, 27, 37, 54], [119, 111, 40, 67], [124, 9, 49, 67]]\n",
"\n",
- "result = iou(y_true, y_pred)\n",
+ "result = iou(y_true, y_pred, box_format=\"xywh\")\n",
"result"
]
}
@@ -361,7 +361,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.18"
+ "version": "3.11.5"
}
},
"nbformat": 4,
diff --git a/src/od_metrics/od_metrics.py b/src/od_metrics/od_metrics.py
index 0aeaac4..ad9db49 100644
--- a/src/od_metrics/od_metrics.py
+++ b/src/od_metrics/od_metrics.py
@@ -16,7 +16,7 @@
import numpy as np
from .constants import DEFAULT_COCO, _STANDARD_OUTPUT
-from .utils import get_indexes, get_suffix, _Missing
+from .utils import get_indexes, get_suffix, _Missing, to_xyxy
from .validators import ConstructorModel, ComputeModel, MeanModel
@@ -483,7 +483,8 @@ def _compute_iou(
ious = iou(
y_true=y_true_boxes,
y_pred=y_pred_boxes,
- iscrowd=iscrowd
+ iscrowd=iscrowd,
+ box_format="xywh",
)
return ious
@@ -1050,13 +1051,11 @@ def iou(
y_true: np.ndarray | list,
y_pred: np.ndarray | list,
iscrowd: np.ndarray | list[bool] | list[int] | None = None,
+ box_format: Literal["xyxy", "xywh", "cxcywh"] = "xywh",
) -> np.ndarray:
"""
Calculate IoU between bounding boxes.
- Single bounding boxes must be in `"xywh"` format, i.e.
- [xmin, ymin, width, height]
-
The standard iou of a ground truth `y_true` and detected
`y_pred` object is:
@@ -1096,6 +1095,23 @@ def iou(
Whether `y_true` are crowd regions.
If `None`, it will be set to `False` for all `y_true`.
The default is `None`.
+ box_format: Literal["xyxy", "xywh", "cxcywh"], optional
+ Bounding box format.
+ Supported formats are:
+ - `"xyxy"`: boxes are represented via corners,
+ x1, y1 being top left and x2, y2
+ being bottom right.
+ - `"xywh"`: boxes are represented via corner,
+ width and height, x1, y2 being top
+ left, w, h being width and height.
+ This is the default format; all
+ input formats will be converted
+ to this.
+ - `"cxcywh"`: boxes are represented via centre,
+ width and height, cx, cy being
+ center of box, w, h being width
+ and height.
+ The default is `"xywh"`.
Returns
-------
@@ -1111,20 +1127,14 @@ def iou(
"length.")
else:
iscrowd = [False]*len(y_true)
- # To np.ndarray
- if not isinstance(y_pred, np.ndarray):
- y_pred = np.array(y_pred)
- if not isinstance(y_true, np.ndarray):
- y_true = np.array(y_true)
+ # To np.ndarray and xyxy box format
+ y_true = np.array([to_xyxy(bbox_, box_format) for bbox_ in y_true])
+ y_pred = np.array([to_xyxy(bbox_, box_format) for bbox_ in y_pred])
# pylint: disable-next=W0632
- xmin1, ymin1, width1, height1 = np.hsplit(y_pred, 4)
+ xmin1, ymin1, xmax1, ymax1 = np.hsplit(y_pred, 4)
# pylint: disable-next=W0632
- xmin2, ymin2, width2, height2 = np.hsplit(y_true, 4)
- xmax1 = xmin1 + width1
- xmax2 = xmin2 + width2
- ymax1 = ymin1 + height1
- ymax2 = ymin2 + height2
+ xmin2, ymin2, xmax2, ymax2 = np.hsplit(y_true, 4)
# Intersection
xmin_i = np.maximum(xmin1.T, xmin2).T
diff --git a/src/od_metrics/utils.py b/src/od_metrics/utils.py
index 6882b3a..8e0aea3 100644
--- a/src/od_metrics/utils.py
+++ b/src/od_metrics/utils.py
@@ -6,6 +6,8 @@
"_Missing",
"get_indexes",
"get_suffix",
+ "to_xywh",
+ "to_xyxy",
]
from typing import Literal
@@ -124,6 +126,87 @@ def to_xywh(
)
+def xywh_xyxy(bbox: list[float]) -> list[float]:
+ """
+ Change bounding box format from `xywh` to `xyxy`.
+
+ Parameters
+ ----------
+ bbox : list[float]
+ Input bounding box.
+
+ Returns
+ -------
+ list[float]
+ Bounding box in `"xyxy"` format.
+ """
+ return [
+ bbox[0],
+ bbox[1],
+ bbox[0] + bbox[2],
+ bbox[1] + bbox[3]
+ ]
+
+
+def cxcywh_xyxy(bbox: list[float]) -> list[float]:
+ """
+ Change bounding box format from `cxcywh` to `xyxy`.
+
+ Parameters
+ ----------
+ bbox : list[float]
+ Input bounding box.
+
+ Returns
+ -------
+ list[float]
+ Bounding box in `"xyxy"` format.
+ """
+ return [
+ bbox[0] - bbox[2] / 2,
+ bbox[1] - bbox[3] / 2,
+ bbox[0] + bbox[2] / 2,
+ bbox[1] + bbox[3] / 2
+ ]
+
+
+def to_xyxy(
+ bbox: list[float],
+ box_format: Literal["xyxy", "xywh", "cxcywh"],
+ ) -> list[float]:
+ """
+ Change bounding box format to `"xyxy"`.
+
+ Parameters
+ ----------
+ bbox : list[float]
+ Input bounding box.
+ box_format : Literal["xyxy", "xywh", "cxcywh"]
+ Input bounding box format.
+ It can be `"xyxy"`, `"xywh"` or `"cxcywh"`.
+
+ Raises
+ ------
+ ValueError
+ If `box_format` not one of `"xyxy"`, `"xywh"`, `"cxcywh"`.
+
+ Returns
+ -------
+ list[float]
+ Bounding box in `"xyxy"` format.
+ """
+ if box_format == "xywh":
+ return xywh_xyxy(bbox)
+ if box_format == "xyxy":
+ return bbox
+ if box_format == "cxcywh":
+ return cxcywh_xyxy(bbox)
+ raise ValueError(
+ "`box_format` can be `'xyxy'`, `'xywh'`, `'cxcywh'`. "
+ f"Found {box_format}"
+ )
+
+
def get_suffix(
iou_threshold: np.ndarray,
area_range_key: np.ndarray,
diff --git a/tests/config.py b/tests/config.py
index 31e4111..973e9f3 100644
--- a/tests/config.py
+++ b/tests/config.py
@@ -776,6 +776,34 @@
"to_cover": {"pycoco_converter": False, "box_format_converter": False},
"ids": "annotations_exception_ypred_no_boxes"
},
+ {
+ "compute_settings": {"extended_summary": True},
+ "y_true": [
+ {"labels": [0, 2],
+ "boxes": [[17, 83, 97, 47], [57, 86, 96, 73]]}
+ ],
+ "y_pred": [
+ {
+ "labels": [0, 2],
+ "boxes": [[-17, -83, 0, 47], [0, 86, -96, 73]],
+ "scores": [.2, .3]}
+ ],
+ "ids": "annotations_with_no_valid_y_pred_boxes"
+ },
+ {
+ "compute_settings": {"extended_summary": True},
+ "y_true": [
+ {"labels": [0, 2],
+ "boxes": [[-17, -83, 0, 47], [0, 86, -96, 73]]}
+ ],
+ "y_pred": [
+ {
+ "labels": [0, 2],
+ "boxes": [[17, 83, 97, 47], [57, 86, 96, 73]],
+ "scores": [.2, .3]}
+ ],
+ "ids": "annotations_with_no_valid_y_true_boxes"
+ },
]
diff --git a/tests/test_odmetrics.py b/tests/test_odmetrics.py
index d118d79..d4b3702 100644
--- a/tests/test_odmetrics.py
+++ b/tests/test_odmetrics.py
@@ -6,11 +6,13 @@
import copy
from typing import Any, Literal
from functools import partial
+from itertools import product
import numpy as np
from parameterized import parameterized, parameterized_class
from src.od_metrics import ODMetrics, iou
from src.od_metrics.constants import DEFAULT_COCO
+from src.od_metrics.utils import to_xywh
from tests.utils import annotations_generator, pycoco_converter, \
test_equality, rename_dict, xywh_to, apply_function
from tests.config import TESTS
@@ -331,6 +333,65 @@ def test_pycoco_equivalence(
self.assertTrue(test_equality(od_metrics_ious, pycoco_ious))
+ @parameterized.expand(list(product(
+ ["random", None], ["xyxy", "xywh", "cxcywh", "error"])))
+ def test_box_formats(
+ self,
+ iscrowd_mode: Literal["random", None],
+ box_format: Literal["xyxy", "xywh", "cxcywh"],
+ ) -> None:
+ """Test `box_format` argument."""
+ if iscrowd_mode == "random":
+ iscrowd = list(map(
+ bool,
+ np.random.randint(
+ low=0,
+ high=2,
+ size=[self.SIZE]
+ ).tolist()
+ )
+ )
+ iscrowd_pycoco = iscrowd
+ else:
+ iscrowd = None
+ iscrowd_pycoco = [False] * self.SIZE
+ y_pred = np.random.randint(
+ low=1,
+ high=self.HIGH,
+ size=[self.SIZE, 4]
+ )
+
+ y_true = np.random.randint(
+ low=1,
+ high=self.HIGH,
+ size=[self.SIZE, 4]
+ )
+
+ if box_format in ["xyxy", "xywh", "cxcywh"]:
+ od_metrics_ious = iou(
+ y_true=y_true,
+ y_pred=y_pred,
+ iscrowd=iscrowd,
+ box_format=box_format,
+ )
+
+ y_true_pycoco = np.array([to_xywh(bbox_, box_format)
+ for bbox_ in y_true])
+ y_pred_pycoco = np.array([to_xywh(bbox_, box_format)
+ for bbox_ in y_pred])
+ pycoco_ious = maskUtils.iou(y_pred_pycoco, y_true_pycoco,
+ iscrowd_pycoco)
+
+ self.assertTrue(test_equality(od_metrics_ious, pycoco_ious))
+ else:
+ with self.assertRaises(ValueError):
+ od_metrics_ious = iou(
+ y_true=y_true,
+ y_pred=y_pred,
+ iscrowd=iscrowd,
+ box_format=box_format,
+ )
+
def test_length_exception(self) -> None:
"""Test exception `iscrowd` and `y_true` different length."""
iscrowd = list(map(
@@ -360,6 +421,55 @@ def test_length_exception(self) -> None:
iscrowd=iscrowd,
)
+ @parameterized.expand(list(product(
+ ["random", None], ["xyxy", "xywh", "cxcywh"])))
+ def test_not_valid_boxes(
+ self,
+ iscrowd_mode: Literal["random", None],
+ box_format: Literal["xyxy", "xywh", "cxcywh"],
+ ) -> None:
+ """Test equivalence for not valid (negative values) boxes."""
+ if iscrowd_mode == "random":
+ iscrowd = list(map(
+ bool,
+ np.random.randint(
+ low=0,
+ high=2,
+ size=[self.SIZE]
+ ).tolist()
+ )
+ )
+ iscrowd_pycoco = iscrowd
+ else:
+ iscrowd = None
+ iscrowd_pycoco = [False] * self.SIZE
+ y_pred = np.random.randint(
+ low=-10,
+ high=1,
+ size=[self.SIZE, 4]
+ )
+ y_true = np.random.randint(
+ low=-10,
+ high=1,
+ size=[self.SIZE, 4]
+ )
+
+ od_metrics_ious = iou(
+ y_true=y_true,
+ y_pred=y_pred,
+ iscrowd=iscrowd,
+ box_format=box_format,
+ )
+
+ y_true_pycoco = np.array([to_xywh(bbox_, box_format)
+ for bbox_ in y_true])
+ y_pred_pycoco = np.array([to_xywh(bbox_, box_format)
+ for bbox_ in y_pred])
+ pycoco_ious = maskUtils.iou(y_pred_pycoco, y_true_pycoco,
+ iscrowd_pycoco)
+
+ self.assertTrue(test_equality(od_metrics_ious, pycoco_ious))
+
if __name__ == "__main__":
unittest.main() # pragma: no cover