diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2ca4357cb..8ddaa08e9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:
additional_dependencies: ["bandit[toml]"]
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.6.9
+ rev: v0.7.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
diff --git a/docs/metrics/common_values.md b/docs/metrics/common_values.md
new file mode 100644
index 000000000..b7600f3f1
--- /dev/null
+++ b/docs/metrics/common_values.md
@@ -0,0 +1,20 @@
+---
+comments: true
+status: new
+---
+
+# Common Values
+
+This page documents supplementary values, types, and enums used by the metrics.
+
+
+
+:::supervision.metrics.core.MetricTarget
+
+
+
+:::supervision.metrics.core.AveragingMethod
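+
+Both enums are passed to metric constructors: `MetricTarget` selects what is
+compared, while `AveragingMethod` controls how per-class results are combined.
+A minimal sketch, using `F1Score` as the example metric (any metric accepting
+these values works the same way):
+
+```python
+from supervision.metrics import F1Score
+from supervision.metrics.core import AveragingMethod, MetricTarget
+
+f1_metric = F1Score(
+    metric_target=MetricTarget.BOXES,
+    averaging_method=AveragingMethod.MACRO,
+)
+```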
diff --git a/docs/metrics/precision.md b/docs/metrics/precision.md
new file mode 100644
index 000000000..c704452ee
--- /dev/null
+++ b/docs/metrics/precision.md
@@ -0,0 +1,18 @@
+---
+comments: true
+status: new
+---
+
+# Precision
+
+
+
+:::supervision.metrics.precision.Precision
+
+
+
+:::supervision.metrics.precision.PrecisionResult
diff --git a/docs/metrics/recall.md b/docs/metrics/recall.md
new file mode 100644
index 000000000..78dde8334
--- /dev/null
+++ b/docs/metrics/recall.md
@@ -0,0 +1,18 @@
+---
+comments: true
+status: new
+---
+
+# Recall
+
+
+
+:::supervision.metrics.recall.Recall
+
+
+
+:::supervision.metrics.recall.RecallResult
diff --git a/mkdocs.yml b/mkdocs.yml
index 3cd867590..b30dbcfcc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -66,7 +66,10 @@ nav:
- Utils: datasets/utils.md
- Metrics:
- mAP: metrics/mean_average_precision.md
+ - Precision: metrics/precision.md
+ - Recall: metrics/recall.md
- F1 Score: metrics/f1_score.md
+ - Common Values: metrics/common_values.md
- Legacy Metrics: detection/metrics.md
- Utils:
- Video: utils/video.md
diff --git a/poetry.lock b/poetry.lock
index 797b71397..0a010aa4d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2371,13 +2371,13 @@ requests = "*"
[[package]]
name = "mkdocs-git-revision-date-localized-plugin"
-version = "1.2.9"
+version = "1.3.0"
description = "Mkdocs plugin that enables displaying the localized date of the last git modification of a markdown file."
optional = false
python-versions = ">=3.8"
files = [
- {file = "mkdocs_git_revision_date_localized_plugin-1.2.9-py3-none-any.whl", hash = "sha256:dea5c8067c23df30275702a1708885500fadf0abfb595b60e698bffc79c7a423"},
- {file = "mkdocs_git_revision_date_localized_plugin-1.2.9.tar.gz", hash = "sha256:df9a50873fba3a42ce9123885f8c53d589e90ef6c2443fe3280ef1e8d33c8f65"},
+ {file = "mkdocs_git_revision_date_localized_plugin-1.3.0-py3-none-any.whl", hash = "sha256:c99377ee119372d57a9e47cff4e68f04cce634a74831c06bc89b33e456e840a1"},
+ {file = "mkdocs_git_revision_date_localized_plugin-1.3.0.tar.gz", hash = "sha256:439e2f14582204050a664c258861c325064d97cdc848c541e48bb034a6c4d0cb"},
]
[package.dependencies]
@@ -2412,13 +2412,13 @@ pygments = ">2.12.0"
[[package]]
name = "mkdocs-material"
-version = "9.5.41"
+version = "9.5.42"
description = "Documentation that simply works"
optional = false
python-versions = ">=3.8"
files = [
- {file = "mkdocs_material-9.5.41-py3-none-any.whl", hash = "sha256:990bc138c33342b5b73e7545915ebc0136e501bfbd8e365735144f5120891d83"},
- {file = "mkdocs_material-9.5.41.tar.gz", hash = "sha256:30fa5d459b4b8130848ecd8e1c908878345d9d8268f7ddbc31eebe88d462d97b"},
+ {file = "mkdocs_material-9.5.42-py3-none-any.whl", hash = "sha256:452a7c5d21284b373f36b981a2cbebfff59263feebeede1bc28652e9c5bbe316"},
+ {file = "mkdocs_material-9.5.42.tar.gz", hash = "sha256:92779b5e9b5934540c574c11647131d217dc540dce72b05feeda088c8eb1b8f2"},
]
[package.dependencies]
@@ -2509,43 +2509,43 @@ files = [
[[package]]
name = "mypy"
-version = "1.12.0"
+version = "1.13.0"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.8"
files = [
- {file = "mypy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4397081e620dc4dc18e2f124d5e1d2c288194c2c08df6bdb1db31c38cd1fe1ed"},
- {file = "mypy-1.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:684a9c508a283f324804fea3f0effeb7858eb03f85c4402a967d187f64562469"},
- {file = "mypy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cabe4cda2fa5eca7ac94854c6c37039324baaa428ecbf4de4567279e9810f9e"},
- {file = "mypy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:060a07b10e999ac9e7fa249ce2bdcfa9183ca2b70756f3bce9df7a92f78a3c0a"},
- {file = "mypy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:0eff042d7257f39ba4ca06641d110ca7d2ad98c9c1fb52200fe6b1c865d360ff"},
- {file = "mypy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b86de37a0da945f6d48cf110d5206c5ed514b1ca2614d7ad652d4bf099c7de7"},
- {file = "mypy-1.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20c7c5ce0c1be0b0aea628374e6cf68b420bcc772d85c3c974f675b88e3e6e57"},
- {file = "mypy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64ee25f05fc2d3d8474985c58042b6759100a475f8237da1f4faf7fcd7e6309"},
- {file = "mypy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:faca7ab947c9f457a08dcb8d9a8664fd438080e002b0fa3e41b0535335edcf7f"},
- {file = "mypy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:5bc81701d52cc8767005fdd2a08c19980de9ec61a25dbd2a937dfb1338a826f9"},
- {file = "mypy-1.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8462655b6694feb1c99e433ea905d46c478041a8b8f0c33f1dab00ae881b2164"},
- {file = "mypy-1.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:923ea66d282d8af9e0f9c21ffc6653643abb95b658c3a8a32dca1eff09c06475"},
- {file = "mypy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1ebf9e796521f99d61864ed89d1fb2926d9ab6a5fab421e457cd9c7e4dd65aa9"},
- {file = "mypy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e478601cc3e3fa9d6734d255a59c7a2e5c2934da4378f3dd1e3411ea8a248642"},
- {file = "mypy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:c72861b7139a4f738344faa0e150834467521a3fba42dc98264e5aa9507dd601"},
- {file = "mypy-1.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52b9e1492e47e1790360a43755fa04101a7ac72287b1a53ce817f35899ba0521"},
- {file = "mypy-1.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:48d3e37dd7d9403e38fa86c46191de72705166d40b8c9f91a3de77350daa0893"},
- {file = "mypy-1.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f106db5ccb60681b622ac768455743ee0e6a857724d648c9629a9bd2ac3f721"},
- {file = "mypy-1.12.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:233e11b3f73ee1f10efada2e6da0f555b2f3a5316e9d8a4a1224acc10e7181d3"},
- {file = "mypy-1.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ae8959c21abcf9d73aa6c74a313c45c0b5a188752bf37dace564e29f06e9c1b"},
- {file = "mypy-1.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:eafc1b7319b40ddabdc3db8d7d48e76cfc65bbeeafaa525a4e0fa6b76175467f"},
- {file = "mypy-1.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9b9ce1ad8daeb049c0b55fdb753d7414260bad8952645367e70ac91aec90e07e"},
- {file = "mypy-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfe012b50e1491d439172c43ccb50db66d23fab714d500b57ed52526a1020bb7"},
- {file = "mypy-1.12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c40658d4fa1ab27cb53d9e2f1066345596af2f8fe4827defc398a09c7c9519b"},
- {file = "mypy-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:dee78a8b9746c30c1e617ccb1307b351ded57f0de0d287ca6276378d770006c0"},
- {file = "mypy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b5df6c8a8224f6b86746bda716bbe4dbe0ce89fd67b1fa4661e11bfe38e8ec8"},
- {file = "mypy-1.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5feee5c74eb9749e91b77f60b30771563327329e29218d95bedbe1257e2fe4b0"},
- {file = "mypy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:77278e8c6ffe2abfba6db4125de55f1024de9a323be13d20e4f73b8ed3402bd1"},
- {file = "mypy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:dcfb754dea911039ac12434d1950d69a2f05acd4d56f7935ed402be09fad145e"},
- {file = "mypy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:06de0498798527451ffb60f68db0d368bd2bae2bbfb5237eae616d4330cc87aa"},
- {file = "mypy-1.12.0-py3-none-any.whl", hash = "sha256:fd313226af375d52e1e36c383f39bf3836e1f192801116b31b090dfcd3ec5266"},
- {file = "mypy-1.12.0.tar.gz", hash = "sha256:65a22d87e757ccd95cbbf6f7e181e6caa87128255eb2b6be901bb71b26d8a99d"},
+ {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"},
+ {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"},
+ {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"},
+ {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"},
+ {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"},
+ {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"},
+ {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"},
+ {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"},
+ {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"},
+ {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"},
+ {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"},
+ {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"},
+ {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"},
+ {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"},
+ {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"},
+ {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"},
+ {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"},
+ {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"},
+ {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"},
+ {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"},
+ {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"},
+ {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"},
+ {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"},
+ {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"},
+ {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"},
+ {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"},
+ {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"},
+ {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"},
+ {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"},
+ {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"},
+ {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"},
+ {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"},
]
[package.dependencies]
@@ -2555,6 +2555,7 @@ typing-extensions = ">=4.6.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
+faster-cache = ["orjson"]
install-types = ["pip"]
mypyc = ["setuptools (>=50)"]
reports = ["lxml"]
@@ -4036,29 +4037,29 @@ files = [
[[package]]
name = "ruff"
-version = "0.6.9"
+version = "0.7.0"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
- {file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"},
- {file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"},
- {file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"},
- {file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"},
- {file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"},
- {file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"},
- {file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"},
- {file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"},
- {file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"},
- {file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"},
- {file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"},
- {file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"},
+ {file = "ruff-0.7.0-py3-none-linux_armv6l.whl", hash = "sha256:0cdf20c2b6ff98e37df47b2b0bd3a34aaa155f59a11182c1303cce79be715628"},
+ {file = "ruff-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:496494d350c7fdeb36ca4ef1c9f21d80d182423718782222c29b3e72b3512737"},
+ {file = "ruff-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:214b88498684e20b6b2b8852c01d50f0651f3cc6118dfa113b4def9f14faaf06"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630fce3fefe9844e91ea5bbf7ceadab4f9981f42b704fae011bb8efcaf5d84be"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:211d877674e9373d4bb0f1c80f97a0201c61bcd1e9d045b6e9726adc42c156aa"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:194d6c46c98c73949a106425ed40a576f52291c12bc21399eb8f13a0f7073495"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:82c2579b82b9973a110fab281860403b397c08c403de92de19568f32f7178598"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9af971fe85dcd5eaed8f585ddbc6bdbe8c217fb8fcf510ea6bca5bdfff56040e"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b641c7f16939b7d24b7bfc0be4102c56562a18281f84f635604e8a6989948914"},
+ {file = "ruff-0.7.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d71672336e46b34e0c90a790afeac8a31954fd42872c1f6adaea1dff76fd44f9"},
+ {file = "ruff-0.7.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ab7d98c7eed355166f367597e513a6c82408df4181a937628dbec79abb2a1fe4"},
+ {file = "ruff-0.7.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1eb54986f770f49edb14f71d33312d79e00e629a57387382200b1ef12d6a4ef9"},
+ {file = "ruff-0.7.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:dc452ba6f2bb9cf8726a84aa877061a2462afe9ae0ea1d411c53d226661c601d"},
+ {file = "ruff-0.7.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4b406c2dce5be9bad59f2de26139a86017a517e6bcd2688da515481c05a2cb11"},
+ {file = "ruff-0.7.0-py3-none-win32.whl", hash = "sha256:f6c968509f767776f524a8430426539587d5ec5c662f6addb6aa25bc2e8195ec"},
+ {file = "ruff-0.7.0-py3-none-win_amd64.whl", hash = "sha256:ff4aabfbaaba880e85d394603b9e75d32b0693152e16fa659a3064a85df7fce2"},
+ {file = "ruff-0.7.0-py3-none-win_arm64.whl", hash = "sha256:10842f69c245e78d6adec7e1db0a7d9ddc2fff0621d730e61657b64fa36f207e"},
+ {file = "ruff-0.7.0.tar.gz", hash = "sha256:47a86360cf62d9cd53ebfb0b5eb0e882193fc191c6d717e8bef4462bc3b9ea2b"},
]
[[package]]
@@ -4377,13 +4378,13 @@ files = [
[[package]]
name = "tox"
-version = "4.22.0"
+version = "4.23.2"
description = "tox is a generic virtualenv management and test command line tool"
optional = false
python-versions = ">=3.8"
files = [
- {file = "tox-4.22.0-py3-none-any.whl", hash = "sha256:03734d9a9ac138cd1a898a372fb1b8079e2728618ae06dc37cbf3686cfb56eea"},
- {file = "tox-4.22.0.tar.gz", hash = "sha256:acc6c627cb3316585238d55d2b633e132fea1bdb01b9d93b56bce7caea6ae73d"},
+ {file = "tox-4.23.2-py3-none-any.whl", hash = "sha256:452bc32bb031f2282881a2118923176445bac783ab97c874b8770ab4c3b76c38"},
+ {file = "tox-4.23.2.tar.gz", hash = "sha256:86075e00e555df6e82e74cfc333917f91ecb47ffbc868dcafbd2672e332f4a2c"},
]
[package.dependencies]
@@ -4399,6 +4400,9 @@ tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.12.2", markers = "python_version < \"3.11\""}
virtualenv = ">=20.26.6"
+[package.extras]
+test = ["devpi-process (>=1.0.2)", "pytest (>=8.3.3)", "pytest-mock (>=3.14)"]
+
[[package]]
name = "tqdm"
version = "4.66.5"
diff --git a/supervision/detection/tools/polygon_zone.py b/supervision/detection/tools/polygon_zone.py
index b8b728fc5..f69f3c9fe 100644
--- a/supervision/detection/tools/polygon_zone.py
+++ b/supervision/detection/tools/polygon_zone.py
@@ -17,6 +17,12 @@ class PolygonZone:
"""
A class for defining a polygon-shaped zone within a frame for detecting objects.
+ !!! warning
+
+        PolygonZone uses the `tracker_id`. Read
+        [here](/latest/trackers/) to learn how to plug
+        tracking into your inference pipeline.
+
Attributes:
polygon (np.ndarray): A polygon represented by a numpy array of shape
`(N, 2)`, containing the `x`, `y` coordinates of the points.
@@ -26,6 +32,28 @@ class PolygonZone:
(default: (sv.Position.BOTTOM_CENTER,)).
current_count (int): The current count of detected objects within the zone
mask (np.ndarray): The 2D bool mask for the polygon zone
+
+ Example:
+ ```python
+ import supervision as sv
+ from ultralytics import YOLO
+ import numpy as np
+ import cv2
+
+        image = cv2.imread("<SOURCE_IMAGE_PATH>")
+        model = YOLO("yolo11s.pt")
+ tracker = sv.ByteTrack()
+
+ polygon = np.array([[100, 200], [200, 100], [300, 200], [200, 300]])
+ polygon_zone = sv.PolygonZone(polygon=polygon)
+
+        result = model(image)[0]
+ detections = sv.Detections.from_ultralytics(result)
+ detections = tracker.update_with_detections(detections)
+
+ is_detections_in_zone = polygon_zone.trigger(detections)
+ print(polygon_zone.current_count)
+ ```
"""
def __init__(
diff --git a/supervision/geometry/utils.py b/supervision/geometry/utils.py
index 8a0ca35c5..2247adc50 100644
--- a/supervision/geometry/utils.py
+++ b/supervision/geometry/utils.py
@@ -16,6 +16,9 @@ def get_polygon_center(polygon: np.ndarray) -> Point:
Point: The center of the polygon, represented as a
Point object with x and y attributes.
+ Raises:
+ ValueError: If the polygon has no vertices.
+
Examples:
```python
import numpy as np
@@ -30,6 +33,9 @@ def get_polygon_center(polygon: np.ndarray) -> Point:
# This is one of the 3 candidate algorithms considered for centroid calculation.
# For a more detailed discussion, see PR #1084 and commit eb33176
+ if len(polygon) == 0:
+ raise ValueError("Polygon must have at least one vertex.")
+
shift_polygon = np.roll(polygon, -1, axis=0)
signed_areas = np.cross(polygon, shift_polygon) / 2
if signed_areas.sum() == 0:
diff --git a/supervision/metrics/__init__.py b/supervision/metrics/__init__.py
index 8ae33e639..90fc17b47 100644
--- a/supervision/metrics/__init__.py
+++ b/supervision/metrics/__init__.py
@@ -8,6 +8,8 @@
MeanAveragePrecision,
MeanAveragePrecisionResult,
)
+from supervision.metrics.precision import Precision, PrecisionResult
+from supervision.metrics.recall import Recall, RecallResult
from supervision.metrics.utils.object_size import (
ObjectSizeCategory,
get_detection_size_category,
diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py
index d1818441e..def5999a0 100644
--- a/supervision/metrics/core.py
+++ b/supervision/metrics/core.py
@@ -37,9 +37,10 @@ class MetricTarget(Enum):
"""
Specifies what type of detection is used to compute the metric.
- * BOXES: xyxy bounding boxes
- * MASKS: Binary masks
- * ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB)
+ Attributes:
+ BOXES: xyxy bounding boxes
+ MASKS: Binary masks
+ ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB)
"""
BOXES = "boxes"
@@ -54,15 +55,16 @@ class AveragingMethod(Enum):
Suppose, before returning the final result, a metric is computed for each class.
How do you combine those to get the final number?
- * MACRO: Calculate the metric for each class and average the results. The simplest
- averaging method, but it does not take class imbalance into account.
- * MICRO: Calculate the metric globally by counting the total true positives, false
- positives, and false negatives. Micro averaging is useful when you want to give
- more importance to classes with more samples. It's also more appropriate if you
- have an imbalance in the number of instances per class.
- * WEIGHTED: Calculate the metric for each class and average the results, weighted by
- the number of true instances of each class. Use weighted averaging if you want
- to take class imbalance into account.
+ Attributes:
+ MACRO: Calculate the metric for each class and average the results. The simplest
+ averaging method, but it does not take class imbalance into account.
+ MICRO: Calculate the metric globally by counting the total true positives, false
+ positives, and false negatives. Micro averaging is useful when you want to
+ give more importance to classes with more samples. It's also more
+ appropriate if you have an imbalance in the number of instances per class.
+ WEIGHTED: Calculate the metric for each class and average the results, weighted
+ by the number of true instances of each class. Use weighted averaging if
+ you want to take class imbalance into account.
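+
+    Example:
+        Suppose a metric (say, precision) comes out to `0.9` for class A with
+        90 instances, and `0.5` for class B with 10 instances. Macro averaging
+        yields `(0.9 + 0.5) / 2 = 0.7`, while weighted averaging yields
+        `(90 * 0.9 + 10 * 0.5) / 100 = 0.86`. Micro averaging instead pools
+        the raw true/false positive counts before computing the metric.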
"""
MACRO = "macro"
diff --git a/supervision/metrics/f1_score.py b/supervision/metrics/f1_score.py
index 2ca5bca5c..ba4fcd59a 100644
--- a/supervision/metrics/f1_score.py
+++ b/supervision/metrics/f1_score.py
@@ -23,11 +23,45 @@
class F1Score(Metric):
+ """
+ F1 Score is a metric used to evaluate object detection models. It is the harmonic
+ mean of precision and recall, calculated at different IoU thresholds.
+
+ In simple terms, F1 Score is a measure of a model's balance between precision and
+ recall (accuracy and completeness), calculated as:
+
+ `F1 = 2 * (precision * recall) / (precision + recall)`
+
+ Example:
+ ```python
+ import supervision as sv
+ from supervision.metrics import F1Score
+
+ predictions = sv.Detections(...)
+ targets = sv.Detections(...)
+
+ f1_metric = F1Score()
+ f1_result = f1_metric.update(predictions, targets).compute()
+
+ print(f1_result)
+ print(f1_result.f1_50)
+ print(f1_result.small_objects.f1_50)
+ ```
+ """
+
def __init__(
self,
metric_target: MetricTarget = MetricTarget.BOXES,
averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
):
+ """
+ Initialize the F1Score metric.
+
+ Args:
+ metric_target (MetricTarget): The type of detection data to use.
+ averaging_method (AveragingMethod): The averaging method used to compute the
+ F1 scores. Determines how the F1 scores are aggregated across classes.
+ """
self._metric_target = metric_target
if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
raise NotImplementedError(
@@ -40,6 +74,9 @@ def __init__(
self._targets_list: List[Detections] = []
def reset(self) -> None:
+ """
+ Reset the metric to its initial state, clearing all stored data.
+ """
self._predictions_list = []
self._targets_list = []
@@ -48,6 +85,16 @@ def update(
predictions: Union[Detections, List[Detections]],
targets: Union[Detections, List[Detections]],
) -> F1Score:
+ """
+ Add new predictions and targets to the metric, but do not compute the result.
+
+ Args:
+ predictions (Union[Detections, List[Detections]]): The predicted detections.
+ targets (Union[Detections, List[Detections]]): The target detections.
+
+ Returns:
+ (F1Score): The updated metric instance.
+ """
if not isinstance(predictions, list):
predictions = [predictions]
if not isinstance(targets, list):
@@ -65,6 +112,13 @@ def update(
return self
def compute(self) -> F1ScoreResult:
+ """
+ Calculate the F1 score metric based on the stored predictions and ground-truth
+ data, at different IoU thresholds.
+
+ Returns:
+ (F1ScoreResult): The F1 score metric result.
+ """
result = self._compute(self._predictions_list, self._targets_list)
small_predictions, small_targets = self._filter_predictions_and_targets_by_size(
@@ -373,7 +427,6 @@ class F1ScoreResult:
The results of the F1 score metric calculation.
Defaults to `0` if no detections or targets were provided.
- Provides a custom `__str__` method for pretty printing.
Attributes:
metric_target (MetricTarget): the type of data used for the metric -
diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py
index dbd60b2e7..8cec50c85 100644
--- a/supervision/metrics/mean_average_precision.py
+++ b/supervision/metrics/mean_average_precision.py
@@ -23,6 +23,27 @@
class MeanAveragePrecision(Metric):
+ """
+    Mean Average Precision (mAP) is a metric used to evaluate object detection
+    models. It is the mean of the average precision (the area under the
+    precision-recall curve) computed across classes and IoU thresholds.
+
+ Example:
+ ```python
+ import supervision as sv
+ from supervision.metrics import MeanAveragePrecision
+
+ predictions = sv.Detections(...)
+ targets = sv.Detections(...)
+
+ map_metric = MeanAveragePrecision()
+ map_result = map_metric.update(predictions, targets).compute()
+
+ print(map_result)
+ print(map_result.map50_95)
+ map_result.plot()
+ ```
+ """
+
def __init__(
self,
metric_target: MetricTarget = MetricTarget.BOXES,
@@ -47,6 +68,9 @@ def __init__(
self._targets_list: List[Detections] = []
def reset(self) -> None:
+ """
+ Reset the metric to its initial state, clearing all stored data.
+ """
self._predictions_list = []
self._targets_list = []
@@ -95,26 +119,10 @@ def compute(
) -> MeanAveragePrecisionResult:
"""
Calculate Mean Average Precision based on predicted and ground-truth
- detections at different thresholds.
+ detections at different thresholds.
Returns:
- (MeanAveragePrecisionResult): New instance of MeanAveragePrecision.
-
- Example:
- ```python
- import supervision as sv
- from supervision.metrics import MeanAveragePrecision
-
- predictions = sv.Detections(...)
- targets = sv.Detections(...)
-
- map_metric = MeanAveragePrecision()
- map_result = map_metric.update(predictions, targets).compute()
-
- print(map_result)
- print(map_result.map50_95)
- map_result.plot()
- ```
+ (MeanAveragePrecisionResult): The Mean Average Precision result.
"""
result = self._compute(self._predictions_list, self._targets_list)
diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py
new file mode 100644
index 000000000..d915e1f49
--- /dev/null
+++ b/supervision/metrics/precision.py
@@ -0,0 +1,603 @@
+from __future__ import annotations
+
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import numpy as np
+from matplotlib import pyplot as plt
+
+from supervision.config import ORIENTED_BOX_COORDINATES
+from supervision.detection.core import Detections
+from supervision.detection.utils import box_iou_batch, mask_iou_batch
+from supervision.draw.color import LEGACY_COLOR_PALETTE
+from supervision.metrics.core import AveragingMethod, Metric, MetricTarget
+from supervision.metrics.utils.object_size import (
+ ObjectSizeCategory,
+ get_detection_size_category,
+)
+from supervision.metrics.utils.utils import ensure_pandas_installed
+
+if TYPE_CHECKING:
+ import pandas as pd
+
+
+class Precision(Metric):
+ """
+    Precision is a metric used to evaluate object detection models. It is the
+    ratio of true positive detections to the total number of predicted
+    detections, calculated at different IoU thresholds.
+
+ In simple terms, Precision is a measure of a model's accuracy, calculated as:
+
+ `Precision = TP / (TP + FP)`
+
+    Here, `TP` is the number of true positives (correct detections), and `FP` is
+    the number of false positives (incorrect detections).
+
+ Example:
+ ```python
+ import supervision as sv
+ from supervision.metrics import Precision
+
+ predictions = sv.Detections(...)
+ targets = sv.Detections(...)
+
+ precision_metric = Precision()
+ precision_result = precision_metric.update(predictions, targets).compute()
+
+ print(precision_result)
+ print(precision_result.precision_at_50)
+ print(precision_result.small_objects.precision_at_50)
+ ```
+ """
+
+ def __init__(
+ self,
+ metric_target: MetricTarget = MetricTarget.BOXES,
+ averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
+ ):
+ """
+ Initialize the Precision metric.
+
+ Args:
+ metric_target (MetricTarget): The type of detection data to use.
+ averaging_method (AveragingMethod): The averaging method used to compute the
+ precision. Determines how the precision is aggregated across classes.
+ """
+ self._metric_target = metric_target
+ if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+ raise NotImplementedError(
+ "Precision is not implemented for oriented bounding boxes."
+ )
+
+ self.averaging_method = averaging_method
+ self._predictions_list: List[Detections] = []
+ self._targets_list: List[Detections] = []
+
+ def reset(self) -> None:
+ """
+ Reset the metric to its initial state, clearing all stored data.
+ """
+ self._predictions_list = []
+ self._targets_list = []
+
+ def update(
+ self,
+ predictions: Union[Detections, List[Detections]],
+ targets: Union[Detections, List[Detections]],
+ ) -> Precision:
+ """
+ Add new predictions and targets to the metric, but do not compute the result.
+
+ Args:
+ predictions (Union[Detections, List[Detections]]): The predicted detections.
+ targets (Union[Detections, List[Detections]]): The target detections.
+
+ Returns:
+ (Precision): The updated metric instance.
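+
+        Example:
+            A minimal sketch; `dataset` here stands for any iterable of
+            matching (predictions, targets) pairs:
+
+            ```python
+            precision_metric = Precision()
+            for predictions, targets in dataset:
+                precision_metric.update(predictions, targets)
+            precision_result = precision_metric.compute()
+            ```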
+ """
+ if not isinstance(predictions, list):
+ predictions = [predictions]
+ if not isinstance(targets, list):
+ targets = [targets]
+
+ if len(predictions) != len(targets):
+ raise ValueError(
+ f"The number of predictions ({len(predictions)}) and"
+ f" targets ({len(targets)}) during the update must be the same."
+ )
+
+ self._predictions_list.extend(predictions)
+ self._targets_list.extend(targets)
+
+ return self
+
+ def compute(self) -> PrecisionResult:
+ """
+ Calculate the precision metric based on the stored predictions and ground-truth
+ data, at different IoU thresholds.
+
+ Returns:
+ (PrecisionResult): The precision metric result.
+ """
+ result = self._compute(self._predictions_list, self._targets_list)
+
+ small_predictions, small_targets = self._filter_predictions_and_targets_by_size(
+ self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL
+ )
+ result.small_objects = self._compute(small_predictions, small_targets)
+
+ medium_predictions, medium_targets = (
+ self._filter_predictions_and_targets_by_size(
+ self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM
+ )
+ )
+ result.medium_objects = self._compute(medium_predictions, medium_targets)
+
+ large_predictions, large_targets = self._filter_predictions_and_targets_by_size(
+ self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE
+ )
+ result.large_objects = self._compute(large_predictions, large_targets)
+
+ return result
+
+ def _compute(
+ self, predictions_list: List[Detections], targets_list: List[Detections]
+ ) -> PrecisionResult:
+ iou_thresholds = np.linspace(0.5, 0.95, 10)
+ stats = []
+
+ for predictions, targets in zip(predictions_list, targets_list):
+ prediction_contents = self._detections_content(predictions)
+ target_contents = self._detections_content(targets)
+
+ if len(targets) > 0:
+ if len(predictions) == 0:
+ stats.append(
+ (
+ np.zeros((0, iou_thresholds.size), dtype=bool),
+ np.zeros((0,), dtype=np.float32),
+ np.zeros((0,), dtype=int),
+ targets.class_id,
+ )
+ )
+
+ else:
+ if self._metric_target == MetricTarget.BOXES:
+ iou = box_iou_batch(target_contents, prediction_contents)
+ elif self._metric_target == MetricTarget.MASKS:
+ iou = mask_iou_batch(target_contents, prediction_contents)
+ else:
+ raise NotImplementedError(
+ "Unsupported metric target for IoU calculation"
+ )
+
+ matches = self._match_detection_batch(
+ predictions.class_id, targets.class_id, iou, iou_thresholds
+ )
+ stats.append(
+ (
+ matches,
+ predictions.confidence,
+ predictions.class_id,
+ targets.class_id,
+ )
+ )
+
+ if not stats:
+ return PrecisionResult(
+ metric_target=self._metric_target,
+ averaging_method=self.averaging_method,
+ precision_scores=np.zeros(iou_thresholds.shape[0]),
+ precision_per_class=np.zeros((0, iou_thresholds.shape[0])),
+ iou_thresholds=iou_thresholds,
+ matched_classes=np.array([], dtype=int),
+ small_objects=None,
+ medium_objects=None,
+ large_objects=None,
+ )
+
+ concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
+ precision_scores, precision_per_class, unique_classes = (
+ self._compute_precision_for_classes(*concatenated_stats)
+ )
+
+ return PrecisionResult(
+ metric_target=self._metric_target,
+ averaging_method=self.averaging_method,
+ precision_scores=precision_scores,
+ precision_per_class=precision_per_class,
+ iou_thresholds=iou_thresholds,
+ matched_classes=unique_classes,
+ small_objects=None,
+ medium_objects=None,
+ large_objects=None,
+ )
+
+ def _compute_precision_for_classes(
+ self,
+ matches: np.ndarray,
+ prediction_confidence: np.ndarray,
+ prediction_class_ids: np.ndarray,
+ true_class_ids: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ sorted_indices = np.argsort(-prediction_confidence)
+ matches = matches[sorted_indices]
+ prediction_class_ids = prediction_class_ids[sorted_indices]
+ unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
+
+ # Shape: PxTh,P,C,C -> CxThx3
+ confusion_matrix = self._compute_confusion_matrix(
+ matches, prediction_class_ids, unique_classes, class_counts
+ )
+
+ # Shape: CxThx3 -> CxTh
+ precision_per_class = self._compute_precision(confusion_matrix)
+
+ # Shape: CxTh -> Th
+ if self.averaging_method == AveragingMethod.MACRO:
+ precision_scores = np.mean(precision_per_class, axis=0)
+ elif self.averaging_method == AveragingMethod.MICRO:
+ confusion_matrix_merged = confusion_matrix.sum(0)
+ precision_scores = self._compute_precision(confusion_matrix_merged)
+ elif self.averaging_method == AveragingMethod.WEIGHTED:
+ class_counts = class_counts.astype(np.float32)
+ precision_scores = np.average(
+ precision_per_class, axis=0, weights=class_counts
+ )
+
+ return precision_scores, precision_per_class, unique_classes
+
+ @staticmethod
+ def _match_detection_batch(
+ predictions_classes: np.ndarray,
+ target_classes: np.ndarray,
+ iou: np.ndarray,
+ iou_thresholds: np.ndarray,
+ ) -> np.ndarray:
+ num_predictions, num_iou_levels = (
+ predictions_classes.shape[0],
+ iou_thresholds.shape[0],
+ )
+ correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
+ correct_class = target_classes[:, None] == predictions_classes
+
+ for i, iou_level in enumerate(iou_thresholds):
+ matched_indices = np.where((iou >= iou_level) & correct_class)
+
+ if matched_indices[0].shape[0]:
+ combined_indices = np.stack(matched_indices, axis=1)
+ iou_values = iou[matched_indices][:, None]
+ matches = np.hstack([combined_indices, iou_values])
+
+ if matched_indices[0].shape[0] > 1:
+ matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+
+ correct[matches[:, 1].astype(int), i] = True
+
+ return correct
+
+ @staticmethod
+ def _compute_confusion_matrix(
+ sorted_matches: np.ndarray,
+ sorted_prediction_class_ids: np.ndarray,
+ unique_classes: np.ndarray,
+ class_counts: np.ndarray,
+ ) -> np.ndarray:
+ """
+ Compute the confusion matrix for each class and IoU threshold.
+
+ Assumes the matches and prediction_class_ids are sorted by confidence
+ in descending order.
+
+ Arguments:
+            sorted_matches: np.ndarray, bool, shape (P, Th), True where the
+                prediction is a true positive at the given IoU threshold.
+ sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing
+ the class id for each prediction.
+ unique_classes: np.ndarray, int, shape (C,), containing the unique
+ class ids.
+ class_counts: np.ndarray, int, shape (C,), containing the number
+ of true instances for each class.
+
+ Returns:
+ np.ndarray, shape (C, Th, 3), containing the true positives, false
+ positives, and false negatives for each class and IoU threshold.
+ """
+
+ num_thresholds = sorted_matches.shape[1]
+ num_classes = unique_classes.shape[0]
+
+ confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
+ for class_idx, class_id in enumerate(unique_classes):
+ is_class = sorted_prediction_class_ids == class_id
+ num_true = class_counts[class_idx]
+ num_predictions = is_class.sum()
+
+ if num_predictions == 0:
+ true_positives = np.zeros(num_thresholds)
+ false_positives = np.zeros(num_thresholds)
+ false_negatives = np.full(num_thresholds, num_true)
+ elif num_true == 0:
+ true_positives = np.zeros(num_thresholds)
+ false_positives = np.full(num_thresholds, num_predictions)
+ false_negatives = np.zeros(num_thresholds)
+ else:
+ true_positives = sorted_matches[is_class].sum(0)
+ false_positives = (1 - sorted_matches[is_class]).sum(0)
+ false_negatives = num_true - true_positives
+ confusion_matrix[class_idx] = np.stack(
+ [true_positives, false_positives, false_negatives], axis=1
+ )
+
+ return confusion_matrix
+
+ @staticmethod
+ def _compute_precision(confusion_matrix: np.ndarray) -> np.ndarray:
+ """
+ Broadcastable function, computing the precision from the confusion matrix.
+
+ Arguments:
+ confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
+ contains the true positives, false positives, and false negatives.
+
+ Returns:
+ np.ndarray, shape (N, ...), containing the precision for each element.
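+
+        Example:
+            A shape-only sketch of the broadcasting behavior:
+
+            ```python
+            cm = np.ones((5, 10, 3))  # 5 classes, 10 IoU thresholds
+            Precision._compute_precision(cm).shape  # -> (5, 10)
+            ```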
+ """
+        if confusion_matrix.shape[-1] != 3:
+ raise ValueError(
+ f"Confusion matrix must have shape (..., 3), got "
+ f"{confusion_matrix.shape}"
+ )
+ true_positives = confusion_matrix[..., 0]
+ false_positives = confusion_matrix[..., 1]
+
+ denominator = true_positives + false_positives
+ precision = np.where(denominator == 0, 0, true_positives / denominator)
+
+ return precision
+
+ def _detections_content(self, detections: Detections) -> np.ndarray:
+ """Return boxes, masks or oriented bounding boxes from detections."""
+ if self._metric_target == MetricTarget.BOXES:
+ return detections.xyxy
+ if self._metric_target == MetricTarget.MASKS:
+ return (
+ detections.mask
+ if detections.mask is not None
+ else np.empty((0, 0, 0), dtype=bool)
+ )
+ if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+            obb = detections.data.get(ORIENTED_BOX_COORDINATES)
+            if obb is not None and len(obb) > 0:
+                return np.array(obb, dtype=np.float32)
+ return np.empty((0, 8), dtype=np.float32)
+ raise ValueError(f"Invalid metric target: {self._metric_target}")
+
+ def _filter_detections_by_size(
+ self, detections: Detections, size_category: ObjectSizeCategory
+ ) -> Detections:
+ """Return a copy of detections with contents filtered by object size."""
+ new_detections = deepcopy(detections)
+ if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
+ return new_detections
+
+ sizes = get_detection_size_category(new_detections, self._metric_target)
+ size_mask = sizes == size_category.value
+
+ new_detections.xyxy = new_detections.xyxy[size_mask]
+ if new_detections.mask is not None:
+ new_detections.mask = new_detections.mask[size_mask]
+ if new_detections.class_id is not None:
+ new_detections.class_id = new_detections.class_id[size_mask]
+ if new_detections.confidence is not None:
+ new_detections.confidence = new_detections.confidence[size_mask]
+ if new_detections.tracker_id is not None:
+ new_detections.tracker_id = new_detections.tracker_id[size_mask]
+ if new_detections.data is not None:
+ for key, value in new_detections.data.items():
+ new_detections.data[key] = np.array(value)[size_mask]
+
+ return new_detections
+
+ def _filter_predictions_and_targets_by_size(
+ self,
+ predictions_list: List[Detections],
+ targets_list: List[Detections],
+ size_category: ObjectSizeCategory,
+ ) -> Tuple[List[Detections], List[Detections]]:
+ """
+ Filter predictions and targets by object size category.
+ """
+ new_predictions_list = []
+ new_targets_list = []
+ for predictions, targets in zip(predictions_list, targets_list):
+ new_predictions_list.append(
+ self._filter_detections_by_size(predictions, size_category)
+ )
+ new_targets_list.append(
+ self._filter_detections_by_size(targets, size_category)
+ )
+ return new_predictions_list, new_targets_list
+
+
+@dataclass
+class PrecisionResult:
+ """
+ The results of the precision metric calculation.
+
+ Defaults to `0` if no detections or targets were provided.
+
+ Attributes:
+ metric_target (MetricTarget): the type of data used for the metric -
+ boxes, masks or oriented bounding boxes.
+ averaging_method (AveragingMethod): the averaging method used to compute the
+ precision. Determines how the precision is aggregated across classes.
+ precision_at_50 (float): the precision at IoU threshold of `0.5`.
+ precision_at_75 (float): the precision at IoU threshold of `0.75`.
+ precision_scores (np.ndarray): the precision scores at each IoU threshold.
+ Shape: `(num_iou_thresholds,)`
+ precision_per_class (np.ndarray): the precision scores per class and
+ IoU threshold. Shape: `(num_target_classes, num_iou_thresholds)`
+ iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
+ matched_classes (np.ndarray): the class IDs of all matched classes.
+ Corresponds to the rows of `precision_per_class`.
+ small_objects (Optional[PrecisionResult]): the Precision metric results
+ for small objects.
+ medium_objects (Optional[PrecisionResult]): the Precision metric results
+ for medium objects.
+ large_objects (Optional[PrecisionResult]): the Precision metric results
+ for large objects.
+ """
+
+ metric_target: MetricTarget
+ averaging_method: AveragingMethod
+
+ @property
+ def precision_at_50(self) -> float:
+ return self.precision_scores[0]
+
+ @property
+ def precision_at_75(self) -> float:
+ return self.precision_scores[5]
+
+ precision_scores: np.ndarray
+ precision_per_class: np.ndarray
+ iou_thresholds: np.ndarray
+ matched_classes: np.ndarray
+
+ small_objects: Optional[PrecisionResult]
+ medium_objects: Optional[PrecisionResult]
+ large_objects: Optional[PrecisionResult]
+
+ def __str__(self) -> str:
+ """
+ Format as a pretty string.
+
+ Example:
+ ```python
+ print(precision_result)
+ ```
+ """
+ out_str = (
+ f"{self.__class__.__name__}:\n"
+ f"Metric target: {self.metric_target}\n"
+ f"Averaging method: {self.averaging_method}\n"
+ f"P @ 50: {self.precision_at_50:.4f}\n"
+ f"P @ 75: {self.precision_at_75:.4f}\n"
+ f"P @ thresh: {self.precision_scores}\n"
+ f"IoU thresh: {self.iou_thresholds}\n"
+ f"Precision per class:\n"
+ )
+ if self.precision_per_class.size == 0:
+ out_str += " No results\n"
+ for class_id, precision_of_class in zip(
+ self.matched_classes, self.precision_per_class
+ ):
+ out_str += f" {class_id}: {precision_of_class}\n"
+
+ indent = " "
+ if self.small_objects is not None:
+ indented = indent + str(self.small_objects).replace("\n", f"\n{indent}")
+ out_str += f"\nSmall objects:\n{indented}"
+ if self.medium_objects is not None:
+ indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}")
+ out_str += f"\nMedium objects:\n{indented}"
+ if self.large_objects is not None:
+ indented = indent + str(self.large_objects).replace("\n", f"\n{indent}")
+ out_str += f"\nLarge objects:\n{indented}"
+
+ return out_str
+
+ def to_pandas(self) -> "pd.DataFrame":
+ """
+ Convert the result to a pandas DataFrame.
+
+ Returns:
+ (pd.DataFrame): The result as a DataFrame.
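+
+        Example:
+            A minimal sketch; `precision_result` is assumed to be a computed
+            `PrecisionResult`:
+
+            ```python
+            df = precision_result.to_pandas()
+            print(df)
+            ```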
+ """
+ ensure_pandas_installed()
+ import pandas as pd
+
+ pandas_data = {
+ "P@50": self.precision_at_50,
+ "P@75": self.precision_at_75,
+ }
+
+ if self.small_objects is not None:
+ small_objects_df = self.small_objects.to_pandas()
+ for key, value in small_objects_df.items():
+ pandas_data[f"small_objects_{key}"] = value
+ if self.medium_objects is not None:
+ medium_objects_df = self.medium_objects.to_pandas()
+ for key, value in medium_objects_df.items():
+ pandas_data[f"medium_objects_{key}"] = value
+ if self.large_objects is not None:
+ large_objects_df = self.large_objects.to_pandas()
+ for key, value in large_objects_df.items():
+ pandas_data[f"large_objects_{key}"] = value
+
+ return pd.DataFrame(pandas_data, index=[0])
+
+ def plot(self):
+ """
+ Plot the precision results.
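+
+        Displays a bar chart of the `P@50` and `P@75` values, with per-size
+        breakdowns when available.
+
+        Example:
+            ```python
+            precision_result.plot()
+            ```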
+ """
+
+ labels = ["Precision@50", "Precision@75"]
+ values = [self.precision_at_50, self.precision_at_75]
+ colors = [LEGACY_COLOR_PALETTE[0]] * 2
+
+ if self.small_objects is not None:
+ small_objects = self.small_objects
+ labels += ["Small: P@50", "Small: P@75"]
+ values += [small_objects.precision_at_50, small_objects.precision_at_75]
+ colors += [LEGACY_COLOR_PALETTE[3]] * 2
+
+ if self.medium_objects is not None:
+ medium_objects = self.medium_objects
+ labels += ["Medium: P@50", "Medium: P@75"]
+ values += [medium_objects.precision_at_50, medium_objects.precision_at_75]
+ colors += [LEGACY_COLOR_PALETTE[2]] * 2
+
+ if self.large_objects is not None:
+ large_objects = self.large_objects
+ labels += ["Large: P@50", "Large: P@75"]
+ values += [large_objects.precision_at_50, large_objects.precision_at_75]
+ colors += [LEGACY_COLOR_PALETTE[4]] * 2
+
+ plt.rcParams["font.family"] = "monospace"
+
+ _, ax = plt.subplots(figsize=(10, 6))
+ ax.set_ylim(0, 1)
+ ax.set_ylabel("Value", fontweight="bold")
+ title = (
+            "Precision, by Object Size"
+ f"\n(target: {self.metric_target.value},"
+ f" averaging: {self.averaging_method.value})"
+ )
+ ax.set_title(title, fontweight="bold")
+
+ x_positions = range(len(labels))
+ bars = ax.bar(x_positions, values, color=colors, align="center")
+
+ ax.set_xticks(x_positions)
+ ax.set_xticklabels(labels, rotation=45, ha="right")
+
+ for bar in bars:
+ y_value = bar.get_height()
+ ax.text(
+ bar.get_x() + bar.get_width() / 2,
+ y_value + 0.02,
+ f"{y_value:.2f}",
+ ha="center",
+ va="bottom",
+ )
+
+ plt.rcParams["font.family"] = "sans-serif"
+
+ plt.tight_layout()
+ plt.show()
diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py
new file mode 100644
index 000000000..9eae24f8e
--- /dev/null
+++ b/supervision/metrics/recall.py
@@ -0,0 +1,601 @@
+from __future__ import annotations
+
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import numpy as np
+from matplotlib import pyplot as plt
+
+from supervision.config import ORIENTED_BOX_COORDINATES
+from supervision.detection.core import Detections
+from supervision.detection.utils import box_iou_batch, mask_iou_batch
+from supervision.draw.color import LEGACY_COLOR_PALETTE
+from supervision.metrics.core import AveragingMethod, Metric, MetricTarget
+from supervision.metrics.utils.object_size import (
+ ObjectSizeCategory,
+ get_detection_size_category,
+)
+from supervision.metrics.utils.utils import ensure_pandas_installed
+
+if TYPE_CHECKING:
+ import pandas as pd
+
+
+class Recall(Metric):
+ """
+    Recall is a metric used to evaluate object detection models. It is the ratio
+    of true positive detections to the total number of ground-truth instances,
+    calculated at different IoU thresholds.
+
+ In simple terms, Recall is a measure of a model's completeness, calculated as:
+
+ `Recall = TP / (TP + FN)`
+
+ Here, `TP` is the number of true positives (correct detections), and `FN` is the
+ number of false negatives (missed detections).
+
+ Example:
+ ```python
+ import supervision as sv
+ from supervision.metrics import Recall
+
+ predictions = sv.Detections(...)
+ targets = sv.Detections(...)
+
+ recall_metric = Recall()
+ recall_result = recall_metric.update(predictions, targets).compute()
+
+ print(recall_result)
+ print(recall_result.recall_at_50)
+ print(recall_result.small_objects.recall_at_50)
+ ```
+ """
+
+ def __init__(
+ self,
+ metric_target: MetricTarget = MetricTarget.BOXES,
+ averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
+ ):
+ """
+ Initialize the Recall metric.
+
+ Args:
+ metric_target (MetricTarget): The type of detection data to use.
+ averaging_method (AveragingMethod): The averaging method used to compute the
+ recall. Determines how the recall is aggregated across classes.
+ """
+ self._metric_target = metric_target
+ if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+ raise NotImplementedError(
+ "Recall is not implemented for oriented bounding boxes."
+ )
+
+ self.averaging_method = averaging_method
+ self._predictions_list: List[Detections] = []
+ self._targets_list: List[Detections] = []
+
+ def reset(self) -> None:
+ """
+ Reset the metric to its initial state, clearing all stored data.
+ """
+ self._predictions_list = []
+ self._targets_list = []
+
+ def update(
+ self,
+ predictions: Union[Detections, List[Detections]],
+ targets: Union[Detections, List[Detections]],
+ ) -> Recall:
+ """
+ Add new predictions and targets to the metric, but do not compute the result.
+
+ Args:
+ predictions (Union[Detections, List[Detections]]): The predicted detections.
+ targets (Union[Detections, List[Detections]]): The target detections.
+
+ Returns:
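+
+        Example:
+            A minimal sketch; `dataset` here stands for any iterable of
+            matching (predictions, targets) pairs:
+
+            ```python
+            recall_metric = Recall()
+            for predictions, targets in dataset:
+                recall_metric.update(predictions, targets)
+            recall_result = recall_metric.compute()
+            ```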
+ (Recall): The updated metric instance.
+ """
+ if not isinstance(predictions, list):
+ predictions = [predictions]
+ if not isinstance(targets, list):
+ targets = [targets]
+
+ if len(predictions) != len(targets):
+ raise ValueError(
+ f"The number of predictions ({len(predictions)}) and"
+ f" targets ({len(targets)}) during the update must be the same."
+ )
+
+ self._predictions_list.extend(predictions)
+ self._targets_list.extend(targets)
+
+ return self
+
+ def compute(self) -> RecallResult:
+ """
+        Calculate the recall metric based on the stored predictions and ground-truth
+        data, at different IoU thresholds.
+
+        Returns:
+            (RecallResult): The recall metric result.
+ """
+ result = self._compute(self._predictions_list, self._targets_list)
+
+ small_predictions, small_targets = self._filter_predictions_and_targets_by_size(
+ self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL
+ )
+ result.small_objects = self._compute(small_predictions, small_targets)
+
+ medium_predictions, medium_targets = (
+ self._filter_predictions_and_targets_by_size(
+ self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM
+ )
+ )
+ result.medium_objects = self._compute(medium_predictions, medium_targets)
+
+ large_predictions, large_targets = self._filter_predictions_and_targets_by_size(
+ self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE
+ )
+ result.large_objects = self._compute(large_predictions, large_targets)
+
+ return result
+
+ def _compute(
+ self, predictions_list: List[Detections], targets_list: List[Detections]
+ ) -> RecallResult:
+ iou_thresholds = np.linspace(0.5, 0.95, 10)
+ stats = []
+
+ for predictions, targets in zip(predictions_list, targets_list):
+ prediction_contents = self._detections_content(predictions)
+ target_contents = self._detections_content(targets)
+
+ if len(targets) > 0:
+ if len(predictions) == 0:
+ stats.append(
+ (
+ np.zeros((0, iou_thresholds.size), dtype=bool),
+ np.zeros((0,), dtype=np.float32),
+ np.zeros((0,), dtype=int),
+ targets.class_id,
+ )
+ )
+
+ else:
+ if self._metric_target == MetricTarget.BOXES:
+ iou = box_iou_batch(target_contents, prediction_contents)
+ elif self._metric_target == MetricTarget.MASKS:
+ iou = mask_iou_batch(target_contents, prediction_contents)
+ else:
+ raise NotImplementedError(
+ "Unsupported metric target for IoU calculation"
+ )
+
+ matches = self._match_detection_batch(
+ predictions.class_id, targets.class_id, iou, iou_thresholds
+ )
+ stats.append(
+ (
+ matches,
+ predictions.confidence,
+ predictions.class_id,
+ targets.class_id,
+ )
+ )
+
+ if not stats:
+ return RecallResult(
+ metric_target=self._metric_target,
+ averaging_method=self.averaging_method,
+ recall_scores=np.zeros(iou_thresholds.shape[0]),
+ recall_per_class=np.zeros((0, iou_thresholds.shape[0])),
+ iou_thresholds=iou_thresholds,
+ matched_classes=np.array([], dtype=int),
+ small_objects=None,
+ medium_objects=None,
+ large_objects=None,
+ )
+
+ concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
+ recall_scores, recall_per_class, unique_classes = (
+ self._compute_recall_for_classes(*concatenated_stats)
+ )
+
+ return RecallResult(
+ metric_target=self._metric_target,
+ averaging_method=self.averaging_method,
+ recall_scores=recall_scores,
+ recall_per_class=recall_per_class,
+ iou_thresholds=iou_thresholds,
+ matched_classes=unique_classes,
+ small_objects=None,
+ medium_objects=None,
+ large_objects=None,
+ )
+
+ def _compute_recall_for_classes(
+ self,
+ matches: np.ndarray,
+ prediction_confidence: np.ndarray,
+ prediction_class_ids: np.ndarray,
+ true_class_ids: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ sorted_indices = np.argsort(-prediction_confidence)
+ matches = matches[sorted_indices]
+ prediction_class_ids = prediction_class_ids[sorted_indices]
+ unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
+
+ # Shape: PxTh,P,C,C -> CxThx3
+ confusion_matrix = self._compute_confusion_matrix(
+ matches, prediction_class_ids, unique_classes, class_counts
+ )
+
+ # Shape: CxThx3 -> CxTh
+ recall_per_class = self._compute_recall(confusion_matrix)
+
+ # Shape: CxTh -> Th
+ if self.averaging_method == AveragingMethod.MACRO:
+ recall_scores = np.mean(recall_per_class, axis=0)
+ elif self.averaging_method == AveragingMethod.MICRO:
+ confusion_matrix_merged = confusion_matrix.sum(0)
+ recall_scores = self._compute_recall(confusion_matrix_merged)
+ elif self.averaging_method == AveragingMethod.WEIGHTED:
+ class_counts = class_counts.astype(np.float32)
+ recall_scores = np.average(recall_per_class, axis=0, weights=class_counts)
+
+ return recall_scores, recall_per_class, unique_classes
+
+ @staticmethod
+ def _match_detection_batch(
+ predictions_classes: np.ndarray,
+ target_classes: np.ndarray,
+ iou: np.ndarray,
+ iou_thresholds: np.ndarray,
+ ) -> np.ndarray:
+ num_predictions, num_iou_levels = (
+ predictions_classes.shape[0],
+ iou_thresholds.shape[0],
+ )
+ correct = np.zeros((num_predictions, num_iou_levels), dtype=bool)
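+        # Shape (num_targets, num_predictions): True where classes match.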
+ correct_class = target_classes[:, None] == predictions_classes
+
+ for i, iou_level in enumerate(iou_thresholds):
+ matched_indices = np.where((iou >= iou_level) & correct_class)
+
+ if matched_indices[0].shape[0]:
+ combined_indices = np.stack(matched_indices, axis=1)
+ iou_values = iou[matched_indices][:, None]
+ matches = np.hstack([combined_indices, iou_values])
+
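+                # Greedy one-to-one matching: sort candidate pairs by IoU
+                # (descending), then keep only the best match per prediction
+                # and per target.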
+ if matched_indices[0].shape[0] > 1:
+ matches = matches[matches[:, 2].argsort()[::-1]]
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+
+ correct[matches[:, 1].astype(int), i] = True
+
+ return correct
+
+ @staticmethod
+ def _compute_confusion_matrix(
+ sorted_matches: np.ndarray,
+ sorted_prediction_class_ids: np.ndarray,
+ unique_classes: np.ndarray,
+ class_counts: np.ndarray,
+ ) -> np.ndarray:
+ """
+ Compute the confusion matrix for each class and IoU threshold.
+
+ Assumes the matches and prediction_class_ids are sorted by confidence
+ in descending order.
+
+ Arguments:
+ sorted_matches: np.ndarray, bool, shape (P, Th), that is True
+ if the prediction is a true positive at the given IoU threshold.
+ sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing
+ the class id for each prediction.
+ unique_classes: np.ndarray, int, shape (C,), containing the unique
+ class ids.
+ class_counts: np.ndarray, int, shape (C,), containing the number
+ of true instances for each class.
+
+ Returns:
+ np.ndarray, shape (C, Th, 3), containing the true positives, false
+ positives, and false negatives for each class and IoU threshold.
+ """
+
+ num_thresholds = sorted_matches.shape[1]
+ num_classes = unique_classes.shape[0]
+
+ confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
+ for class_idx, class_id in enumerate(unique_classes):
+ is_class = sorted_prediction_class_ids == class_id
+ num_true = class_counts[class_idx]
+ num_predictions = is_class.sum()
+
+ if num_predictions == 0:
+ true_positives = np.zeros(num_thresholds)
+ false_positives = np.zeros(num_thresholds)
+ false_negatives = np.full(num_thresholds, num_true)
+ elif num_true == 0:
+ true_positives = np.zeros(num_thresholds)
+ false_positives = np.full(num_thresholds, num_predictions)
+ false_negatives = np.zeros(num_thresholds)
+ else:
+ true_positives = sorted_matches[is_class].sum(0)
+ false_positives = (1 - sorted_matches[is_class]).sum(0)
+ false_negatives = num_true - true_positives
+ confusion_matrix[class_idx] = np.stack(
+ [true_positives, false_positives, false_negatives], axis=1
+ )
+
+ return confusion_matrix
+
+ @staticmethod
+ def _compute_recall(confusion_matrix: np.ndarray) -> np.ndarray:
+ """
+ Broadcastable function, computing the recall from the confusion matrix.
+
+ Arguments:
+ confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
+ contains the true positives, false positives, and false negatives.
+
+ Returns:
+ np.ndarray, shape (N, ...), containing the recall for each element.
+ """
+        if confusion_matrix.shape[-1] != 3:
+ raise ValueError(
+ f"Confusion matrix must have shape (..., 3), got "
+ f"{confusion_matrix.shape}"
+ )
+ true_positives = confusion_matrix[..., 0]
+ false_negatives = confusion_matrix[..., 2]
+
+        denominator = true_positives + false_negatives
+        # Recall is defined as 0 when there are no true instances; `where`
+        # avoids evaluating the division for zero denominators.
+        recall = np.divide(
+            true_positives,
+            denominator,
+            out=np.zeros_like(true_positives),
+            where=denominator != 0,
+        )
+
+ return recall
+
+ def _detections_content(self, detections: Detections) -> np.ndarray:
+ """Return boxes, masks or oriented bounding boxes from detections."""
+ if self._metric_target == MetricTarget.BOXES:
+ return detections.xyxy
+ if self._metric_target == MetricTarget.MASKS:
+ return (
+ detections.mask
+ if detections.mask is not None
+ else np.empty((0, 0, 0), dtype=bool)
+ )
+        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+            obb = detections.data.get(ORIENTED_BOX_COORDINATES)
+            if obb is not None and len(obb) > 0:
+                return np.array(obb, dtype=np.float32)
+            return np.empty((0, 8), dtype=np.float32)
+ raise ValueError(f"Invalid metric target: {self._metric_target}")
+
+ def _filter_detections_by_size(
+ self, detections: Detections, size_category: ObjectSizeCategory
+ ) -> Detections:
+ """Return a copy of detections with contents filtered by object size."""
+ new_detections = deepcopy(detections)
+ if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
+ return new_detections
+
+ sizes = get_detection_size_category(new_detections, self._metric_target)
+ size_mask = sizes == size_category.value
+
+ new_detections.xyxy = new_detections.xyxy[size_mask]
+ if new_detections.mask is not None:
+ new_detections.mask = new_detections.mask[size_mask]
+ if new_detections.class_id is not None:
+ new_detections.class_id = new_detections.class_id[size_mask]
+ if new_detections.confidence is not None:
+ new_detections.confidence = new_detections.confidence[size_mask]
+ if new_detections.tracker_id is not None:
+ new_detections.tracker_id = new_detections.tracker_id[size_mask]
+ if new_detections.data is not None:
+ for key, value in new_detections.data.items():
+ new_detections.data[key] = np.array(value)[size_mask]
+
+ return new_detections
+
+ def _filter_predictions_and_targets_by_size(
+ self,
+ predictions_list: List[Detections],
+ targets_list: List[Detections],
+ size_category: ObjectSizeCategory,
+ ) -> Tuple[List[Detections], List[Detections]]:
+ """
+ Filter predictions and targets by object size category.
+ """
+ new_predictions_list = []
+ new_targets_list = []
+ for predictions, targets in zip(predictions_list, targets_list):
+ new_predictions_list.append(
+ self._filter_detections_by_size(predictions, size_category)
+ )
+ new_targets_list.append(
+ self._filter_detections_by_size(targets, size_category)
+ )
+ return new_predictions_list, new_targets_list
+
+
+@dataclass
+class RecallResult:
+ """
+ The results of the recall metric calculation.
+
+ Defaults to `0` if no detections or targets were provided.
+
+ Attributes:
+ metric_target (MetricTarget): the type of data used for the metric -
+ boxes, masks or oriented bounding boxes.
+ averaging_method (AveragingMethod): the averaging method used to compute the
+ recall. Determines how the recall is aggregated across classes.
+ recall_at_50 (float): the recall at IoU threshold of `0.5`.
+ recall_at_75 (float): the recall at IoU threshold of `0.75`.
+ recall_scores (np.ndarray): the recall scores at each IoU threshold.
+ Shape: `(num_iou_thresholds,)`
+ recall_per_class (np.ndarray): the recall scores per class and IoU threshold.
+ Shape: `(num_target_classes, num_iou_thresholds)`
+ iou_thresholds (np.ndarray): the IoU thresholds used in the calculations.
+ matched_classes (np.ndarray): the class IDs of all matched classes.
+ Corresponds to the rows of `recall_per_class`.
+ small_objects (Optional[RecallResult]): the Recall metric results
+ for small objects.
+ medium_objects (Optional[RecallResult]): the Recall metric results
+ for medium objects.
+ large_objects (Optional[RecallResult]): the Recall metric results
+ for large objects.
+ """
+
+    metric_target: MetricTarget
+    averaging_method: AveragingMethod
+
+    recall_scores: np.ndarray
+    recall_per_class: np.ndarray
+    iou_thresholds: np.ndarray
+    matched_classes: np.ndarray
+
+    small_objects: Optional[RecallResult]
+    medium_objects: Optional[RecallResult]
+    large_objects: Optional[RecallResult]
+
+    @property
+    def recall_at_50(self) -> float:
+        return self.recall_scores[0]
+
+    @property
+    def recall_at_75(self) -> float:
+        return self.recall_scores[5]
+
+ def __str__(self) -> str:
+ """
+ Format as a pretty string.
+
+ Example:
+ ```python
+ print(recall_result)
+ ```
+ """
+ out_str = (
+ f"{self.__class__.__name__}:\n"
+ f"Metric target: {self.metric_target}\n"
+ f"Averaging method: {self.averaging_method}\n"
+ f"R @ 50: {self.recall_at_50:.4f}\n"
+ f"R @ 75: {self.recall_at_75:.4f}\n"
+ f"R @ thresh: {self.recall_scores}\n"
+ f"IoU thresh: {self.iou_thresholds}\n"
+ f"Recall per class:\n"
+ )
+ if self.recall_per_class.size == 0:
+ out_str += " No results\n"
+ for class_id, recall_of_class in zip(
+ self.matched_classes, self.recall_per_class
+ ):
+ out_str += f" {class_id}: {recall_of_class}\n"
+
+ indent = " "
+ if self.small_objects is not None:
+ indented = indent + str(self.small_objects).replace("\n", f"\n{indent}")
+ out_str += f"\nSmall objects:\n{indented}"
+ if self.medium_objects is not None:
+ indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}")
+ out_str += f"\nMedium objects:\n{indented}"
+ if self.large_objects is not None:
+ indented = indent + str(self.large_objects).replace("\n", f"\n{indent}")
+ out_str += f"\nLarge objects:\n{indented}"
+
+ return out_str
+
+ def to_pandas(self) -> "pd.DataFrame":
+ """
+ Convert the result to a pandas DataFrame.
+
+ Returns:
+ (pd.DataFrame): The result as a DataFrame.
+ """
+ ensure_pandas_installed()
+ import pandas as pd
+
+ pandas_data = {
+ "R@50": self.recall_at_50,
+ "R@75": self.recall_at_75,
+ }
+
+ if self.small_objects is not None:
+ small_objects_df = self.small_objects.to_pandas()
+ for key, value in small_objects_df.items():
+ pandas_data[f"small_objects_{key}"] = value
+ if self.medium_objects is not None:
+ medium_objects_df = self.medium_objects.to_pandas()
+ for key, value in medium_objects_df.items():
+ pandas_data[f"medium_objects_{key}"] = value
+ if self.large_objects is not None:
+ large_objects_df = self.large_objects.to_pandas()
+ for key, value in large_objects_df.items():
+ pandas_data[f"large_objects_{key}"] = value
+
+ return pd.DataFrame(pandas_data, index=[0])
+
+ def plot(self):
+ """
+ Plot the recall results.
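+
+        Example:
+            A minimal sketch, assuming `recall_result` is a previously
+            computed `RecallResult`:
+
+            ```python
+            recall_result.plot()
+            ```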
+ """
+
+ labels = ["Recall@50", "Recall@75"]
+ values = [self.recall_at_50, self.recall_at_75]
+ colors = [LEGACY_COLOR_PALETTE[0]] * 2
+
+ if self.small_objects is not None:
+ small_objects = self.small_objects
+ labels += ["Small: R@50", "Small: R@75"]
+ values += [small_objects.recall_at_50, small_objects.recall_at_75]
+ colors += [LEGACY_COLOR_PALETTE[3]] * 2
+
+ if self.medium_objects is not None:
+ medium_objects = self.medium_objects
+ labels += ["Medium: R@50", "Medium: R@75"]
+ values += [medium_objects.recall_at_50, medium_objects.recall_at_75]
+ colors += [LEGACY_COLOR_PALETTE[2]] * 2
+
+ if self.large_objects is not None:
+ large_objects = self.large_objects
+ labels += ["Large: R@50", "Large: R@75"]
+ values += [large_objects.recall_at_50, large_objects.recall_at_75]
+ colors += [LEGACY_COLOR_PALETTE[4]] * 2
+
+ plt.rcParams["font.family"] = "monospace"
+
+ _, ax = plt.subplots(figsize=(10, 6))
+ ax.set_ylim(0, 1)
+ ax.set_ylabel("Value", fontweight="bold")
+ title = (
+ f"Recall, by Object Size"
+ f"\n(target: {self.metric_target.value},"
+ f" averaging: {self.averaging_method.value})"
+ )
+ ax.set_title(title, fontweight="bold")
+
+ x_positions = range(len(labels))
+ bars = ax.bar(x_positions, values, color=colors, align="center")
+
+ ax.set_xticks(x_positions)
+ ax.set_xticklabels(labels, rotation=45, ha="right")
+
+ for bar in bars:
+ y_value = bar.get_height()
+ ax.text(
+ bar.get_x() + bar.get_width() / 2,
+ y_value + 0.02,
+ f"{y_value:.2f}",
+ ha="center",
+ va="bottom",
+ )
+
+ plt.rcParams["font.family"] = "sans-serif"
+
+ plt.tight_layout()
+ plt.show()
diff --git a/supervision/tracker/byte_tracker/basetrack.py b/supervision/tracker/byte_tracker/basetrack.py
index 806f75384..e2bc66d57 100644
--- a/supervision/tracker/byte_tracker/basetrack.py
+++ b/supervision/tracker/byte_tracker/basetrack.py
@@ -12,8 +12,6 @@ class TrackState(Enum):
class BaseTrack:
- _count = 0
-
def __init__(self):
self.track_id = 0
self.is_activated = False
@@ -34,20 +32,13 @@ def __init__(self):
def end_frame(self) -> int:
return self.frame_id
- @staticmethod
- def next_id() -> int:
- BaseTrack._count += 1
- return BaseTrack._count
-
- @staticmethod
- def reset_counter():
- BaseTrack._count = 0
- BaseTrack.track_id = 0
- BaseTrack.start_frame = 0
- BaseTrack.frame_id = 0
- BaseTrack.time_since_update = 0
+ def reset_counter(self):
+ self.track_id = 0
+ self.start_frame = 0
+ self.frame_id = 0
+ self.time_since_update = 0
- def activate(self, *args):
+ def activate(self, *args, **kwargs):
raise NotImplementedError
def predict(self):
diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py
index 89e1e2f2c..d1d567fab 100644
--- a/supervision/tracker/byte_tracker/core.py
+++ b/supervision/tracker/byte_tracker/core.py
@@ -9,11 +9,35 @@
from supervision.tracker.byte_tracker.kalman_filter import KalmanFilter
+class IdCounter:
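+    """Issues sequential integer IDs; `NO_ID` (0) marks an unassigned ID."""
+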
+ def __init__(self):
+ self.reset()
+
+ def reset(self) -> None:
+ self._id = self.NO_ID
+
+ def new_id(self) -> int:
+ self._id += 1
+ return self._id
+
+ @property
+ def NO_ID(self) -> int:
+ return 0
+
+
class STrack(BaseTrack):
shared_kalman = KalmanFilter()
- _external_count = 0
- def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames):
+ def __init__(
+ self,
+ tlwh,
+ score,
+ class_ids,
+ minimum_consecutive_frames,
+ internal_id_counter: IdCounter,
+ external_id_counter: IdCounter,
+ ):
+ super().__init__()
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float32)
self.kalman_filter = None
@@ -24,10 +48,13 @@ def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames):
self.class_ids = class_ids
self.tracklet_len = 0
- self.external_track_id = -1
-
self.minimum_consecutive_frames = minimum_consecutive_frames
+ self.internal_id_counter = internal_id_counter
+ self.external_id_counter = external_id_counter
+ self.internal_track_id = self.internal_id_counter.NO_ID
+ self.external_track_id = self.external_id_counter.NO_ID
+
def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
@@ -57,7 +84,7 @@ def multi_predict(stracks):
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
- self.internal_track_id = self.next_id()
+ self.internal_track_id = self.internal_id_counter.new_id()
self.mean, self.covariance = self.kalman_filter.initiate(
self.tlwh_to_xyah(self._tlwh)
)
@@ -68,12 +95,12 @@ def activate(self, kalman_filter, frame_id):
self.is_activated = True
if self.minimum_consecutive_frames == 1:
- self.external_track_id = self.next_external_id()
+ self.external_track_id = self.external_id_counter.new_id()
self.frame_id = frame_id
self.start_frame = frame_id
- def re_activate(self, new_track, frame_id, new_id=False):
+ def re_activate(self, new_track, frame_id):
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
)
@@ -81,8 +108,6 @@ def re_activate(self, new_track, frame_id, new_id=False):
self.state = TrackState.Tracked
self.frame_id = frame_id
- if new_id:
- self.internal_track_id = self.next_id()
self.score = new_track.score
def update(self, new_track, frame_id):
@@ -103,8 +128,8 @@ def update(self, new_track, frame_id):
self.state = TrackState.Tracked
if self.tracklet_len == self.minimum_consecutive_frames:
self.is_activated = True
- if self.external_track_id == -1:
- self.external_track_id = self.next_external_id()
+ if self.external_track_id == self.external_id_counter.NO_ID:
+ self.external_track_id = self.external_id_counter.new_id()
self.score = new_track.score
@@ -142,15 +167,6 @@ def tlwh_to_xyah(tlwh):
def to_xyah(self):
return self.tlwh_to_xyah(self.tlwh)
- @staticmethod
- def next_external_id():
- STrack._external_count += 1
- return STrack._external_count
-
- @staticmethod
- def reset_external_counter():
- STrack._external_count = 0
-
@staticmethod
def tlbr_to_tlwh(tlbr):
ret = np.asarray(tlbr).copy()
@@ -169,24 +185,6 @@ def __repr__(self):
)
-def detections2boxes(detections: Detections) -> np.ndarray:
- """
- Convert Supervision Detections to numpy tensors for further computation.
- Args:
- detections (Detections): Detections/Targets in the format of sv.Detections.
- Returns:
- (np.ndarray): Detections as numpy tensors as in
- `(x_min, y_min, x_max, y_max, confidence, class_id)` order.
- """
- return np.hstack(
- (
- detections.xyxy,
- detections.confidence[:, np.newaxis],
- detections.class_id[:, np.newaxis],
- )
- )
-
-
class ByteTrack:
"""
Initialize the ByteTrack object.
@@ -235,6 +233,9 @@ def __init__(
self.lost_tracks: List[STrack] = []
self.removed_tracks: List[STrack] = []
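+        # Per-instance counters replace the former class-level counters,
+        # so separate tracker instances no longer share ID state.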
+ self.internal_id_counter = IdCounter()
+ self.external_id_counter = IdCounter()
+
def update_with_detections(self, detections: Detections) -> Detections:
"""
Updates the tracker with the provided detections and returns the updated
@@ -275,7 +276,13 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
```
"""
- tensors = detections2boxes(detections=detections)
+ tensors = np.hstack(
+ (
+ detections.xyxy,
+ detections.confidence[:, np.newaxis],
+ detections.class_id[:, np.newaxis],
+ )
+ )
tracks = self.update_with_tensors(tensors=tensors)
if len(tracks) > 0:
@@ -311,11 +318,12 @@ def reset(self):
ensuring the tracker starts with a clean state for each new video.
"""
self.frame_id = 0
+ self.internal_id_counter.reset()
+ self.external_id_counter.reset()
self.tracked_tracks: List[STrack] = []
self.lost_tracks: List[STrack] = []
self.removed_tracks: List[STrack] = []
- BaseTrack.reset_counter()
- STrack.reset_external_counter()
def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
"""
@@ -353,7 +361,14 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
if len(dets) > 0:
"""Detections"""
detections = [
- STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
+ STrack(
+ STrack.tlbr_to_tlwh(tlbr),
+ s,
+ c,
+ self.minimum_consecutive_frames,
+ self.internal_id_counter,
+ self.external_id_counter,
+ )
for (tlbr, s, c) in zip(dets, scores_keep, class_ids_keep)
]
else:
@@ -387,7 +402,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
track.update(detections[idet], self.frame_id)
activated_starcks.append(track)
else:
- track.re_activate(det, self.frame_id, new_id=False)
+ track.re_activate(det, self.frame_id)
refind_stracks.append(track)
""" Step 3: Second association, with low score detection boxes"""
@@ -395,7 +410,14 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
if len(dets_second) > 0:
"""Detections"""
detections_second = [
- STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
+ STrack(
+ STrack.tlbr_to_tlwh(tlbr),
+ s,
+ c,
+ self.minimum_consecutive_frames,
+ self.internal_id_counter,
+ self.external_id_counter,
+ )
for (tlbr, s, c) in zip(dets_second, scores_second, class_ids_second)
]
else:
@@ -416,7 +438,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
track.update(det, self.frame_id)
activated_starcks.append(track)
else:
- track.re_activate(det, self.frame_id, new_id=False)
+ track.re_activate(det, self.frame_id)
refind_stracks.append(track)
for it in u_track:
diff --git a/supervision/tracker/byte_tracker/matching.py b/supervision/tracker/byte_tracker/matching.py
index 24abe224c..f791e518f 100644
--- a/supervision/tracker/byte_tracker/matching.py
+++ b/supervision/tracker/byte_tracker/matching.py
@@ -20,7 +20,7 @@ def indices_to_matches(
def linear_assignment(
cost_matrix: np.ndarray, thresh: float
-) -> [np.ndarray, Tuple[int], Tuple[int, int]]:
+) -> Tuple[np.ndarray, Tuple[int], Tuple[int, int]]:
if cost_matrix.size == 0:
return (
np.empty((0, 2), dtype=int),
diff --git a/test/tracker/__init__.py b/test/tracker/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/tracker/test_byte_tracker.py b/test/tracker/test_byte_tracker.py
new file mode 100644
index 000000000..98efeb093
--- /dev/null
+++ b/test/tracker/test_byte_tracker.py
@@ -0,0 +1,40 @@
+from typing import List
+
+import numpy as np
+import pytest
+
+import supervision as sv
+
+
+@pytest.mark.parametrize(
+ "detections, expected_results",
+ [
+ (
+ [
+ sv.Detections(
+ xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
+ class_id=np.array([1, 1]),
+ confidence=np.array([1, 1]),
+ ),
+ sv.Detections(
+ xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
+ class_id=np.array([1, 1]),
+ confidence=np.array([1, 1]),
+ ),
+ ],
+ sv.Detections(
+ xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
+ class_id=np.array([1, 1]),
+ confidence=np.array([1, 1]),
+ tracker_id=np.array([1, 2]),
+ ),
+ ),
+ ],
+)
+def test_byte_tracker(
+ detections: List[sv.Detections],
+ expected_results: sv.Detections,
+) -> None:
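+    # Two identical frames should yield stable tracker IDs (1 and 2),
+    # issued by the tracker's own per-instance ID counter.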
+ byte_tracker = sv.ByteTrack()
+ tracked_detections = [byte_tracker.update_with_detections(d) for d in detections]
+ assert tracked_detections[-1] == expected_results