diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2ca4357cb..8ddaa08e9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:
      additional_dependencies: ["bandit[toml]"]
  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.9
+    rev: v0.7.0
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
diff --git a/docs/metrics/common_values.md b/docs/metrics/common_values.md
new file mode 100644
index 000000000..b7600f3f1
--- /dev/null
+++ b/docs/metrics/common_values.md
@@ -0,0 +1,20 @@
+---
+comments: true
+status: new
+---
+
+# Common Values
+
+This page contains supplementary values, types and enums that metrics use.
+
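Both enums documented on this page are passed to a metric's constructor. A minimal sketch of how they fit together, using `F1Score` as the host metric (the `Precision` and `Recall` metrics added in this PR accept the same two arguments); the import paths follow the directives below:

```python
from supervision.metrics import F1Score
from supervision.metrics.core import AveragingMethod, MetricTarget

# Evaluate instance masks instead of boxes, and average per-class
# scores equally (macro averaging ignores class imbalance).
f1_metric = F1Score(
    metric_target=MetricTarget.MASKS,
    averaging_method=AveragingMethod.MACRO,
)
```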
+<div class="md-typeset">
+    <h2><a href="#supervision.metrics.core.MetricTarget">MetricTarget</a></h2>
+</div>
+
+:::supervision.metrics.core.MetricTarget
+
+<div class="md-typeset">
+    <h2><a href="#supervision.metrics.core.AveragingMethod">AveragingMethod</a></h2>
+</div>
+ +:::supervision.metrics.core.AveragingMethod diff --git a/docs/metrics/precision.md b/docs/metrics/precision.md new file mode 100644 index 000000000..c704452ee --- /dev/null +++ b/docs/metrics/precision.md @@ -0,0 +1,18 @@ +--- +comments: true +status: new +--- + +# Precision + +
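Ahead of the generated reference below, a quick-start sketch; `predictions` and `targets` are placeholders for real `sv.Detections` coming from a model and a labeled dataset:

```python
import supervision as sv
from supervision.metrics import Precision

predictions = sv.Detections(...)  # placeholder model output
targets = sv.Detections(...)      # placeholder ground truth

precision_metric = Precision()
precision_result = precision_metric.update(predictions, targets).compute()
print(precision_result.precision_at_50)
```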
+<div class="md-typeset">
+    <h2><a href="#supervision.metrics.precision.Precision">Precision</a></h2>
+</div>
+
+:::supervision.metrics.precision.Precision
+
+<div class="md-typeset">
+    <h2><a href="#supervision.metrics.precision.PrecisionResult">PrecisionResult</a></h2>
+</div>
+ +:::supervision.metrics.precision.PrecisionResult diff --git a/docs/metrics/recall.md b/docs/metrics/recall.md new file mode 100644 index 000000000..78dde8334 --- /dev/null +++ b/docs/metrics/recall.md @@ -0,0 +1,18 @@ +--- +comments: true +status: new +--- + +# Recall + +
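As elsewhere in `supervision.metrics`, `update()` also accepts lists of `sv.Detections` (one pair per image), so a whole dataset can be accumulated before `compute()` is called; a sketch with placeholder detections:

```python
import supervision as sv
from supervision.metrics import Recall

# One Detections pair per evaluated image (placeholders here).
predictions_per_image = [sv.Detections(...), sv.Detections(...)]
targets_per_image = [sv.Detections(...), sv.Detections(...)]

recall_metric = Recall()
recall_metric.update(predictions_per_image, targets_per_image)
print(recall_metric.compute().recall_at_50)
```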
+<div class="md-typeset">
+    <h2><a href="#supervision.metrics.recall.Recall">Recall</a></h2>
+</div>
+
+:::supervision.metrics.recall.Recall
+
+<div class="md-typeset">
+    <h2><a href="#supervision.metrics.recall.RecallResult">RecallResult</a></h2>
+</div>
+ +:::supervision.metrics.recall.RecallResult diff --git a/mkdocs.yml b/mkdocs.yml index 3cd867590..b30dbcfcc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -66,7 +66,10 @@ nav: - Utils: datasets/utils.md - Metrics: - mAP: metrics/mean_average_precision.md + - Precision: metrics/precision.md + - Recall: metrics/recall.md - F1 Score: metrics/f1_score.md + - Common Values: metrics/common_values.md - Legacy Metrics: detection/metrics.md - Utils: - Video: utils/video.md diff --git a/poetry.lock b/poetry.lock index 797b71397..0a010aa4d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2371,13 +2371,13 @@ requests = "*" [[package]] name = "mkdocs-git-revision-date-localized-plugin" -version = "1.2.9" +version = "1.3.0" description = "Mkdocs plugin that enables displaying the localized date of the last git modification of a markdown file." optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_git_revision_date_localized_plugin-1.2.9-py3-none-any.whl", hash = "sha256:dea5c8067c23df30275702a1708885500fadf0abfb595b60e698bffc79c7a423"}, - {file = "mkdocs_git_revision_date_localized_plugin-1.2.9.tar.gz", hash = "sha256:df9a50873fba3a42ce9123885f8c53d589e90ef6c2443fe3280ef1e8d33c8f65"}, + {file = "mkdocs_git_revision_date_localized_plugin-1.3.0-py3-none-any.whl", hash = "sha256:c99377ee119372d57a9e47cff4e68f04cce634a74831c06bc89b33e456e840a1"}, + {file = "mkdocs_git_revision_date_localized_plugin-1.3.0.tar.gz", hash = "sha256:439e2f14582204050a664c258861c325064d97cdc848c541e48bb034a6c4d0cb"}, ] [package.dependencies] @@ -2412,13 +2412,13 @@ pygments = ">2.12.0" [[package]] name = "mkdocs-material" -version = "9.5.41" +version = "9.5.42" description = "Documentation that simply works" optional = false python-versions = ">=3.8" files = [ - {file = "mkdocs_material-9.5.41-py3-none-any.whl", hash = "sha256:990bc138c33342b5b73e7545915ebc0136e501bfbd8e365735144f5120891d83"}, - {file = "mkdocs_material-9.5.41.tar.gz", hash = "sha256:30fa5d459b4b8130848ecd8e1c908878345d9d8268f7ddbc31eebe88d462d97b"}, + {file = "mkdocs_material-9.5.42-py3-none-any.whl", hash = "sha256:452a7c5d21284b373f36b981a2cbebfff59263feebeede1bc28652e9c5bbe316"}, + {file = "mkdocs_material-9.5.42.tar.gz", hash = "sha256:92779b5e9b5934540c574c11647131d217dc540dce72b05feeda088c8eb1b8f2"}, ] [package.dependencies] @@ -2509,43 +2509,43 @@ files = [ [[package]] name = "mypy" -version = "1.12.0" +version = "1.13.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4397081e620dc4dc18e2f124d5e1d2c288194c2c08df6bdb1db31c38cd1fe1ed"}, - {file = "mypy-1.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:684a9c508a283f324804fea3f0effeb7858eb03f85c4402a967d187f64562469"}, - {file = "mypy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cabe4cda2fa5eca7ac94854c6c37039324baaa428ecbf4de4567279e9810f9e"}, - {file = "mypy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:060a07b10e999ac9e7fa249ce2bdcfa9183ca2b70756f3bce9df7a92f78a3c0a"}, - {file = "mypy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:0eff042d7257f39ba4ca06641d110ca7d2ad98c9c1fb52200fe6b1c865d360ff"}, - {file = "mypy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b86de37a0da945f6d48cf110d5206c5ed514b1ca2614d7ad652d4bf099c7de7"}, - {file = "mypy-1.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20c7c5ce0c1be0b0aea628374e6cf68b420bcc772d85c3c974f675b88e3e6e57"}, 
- {file = "mypy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64ee25f05fc2d3d8474985c58042b6759100a475f8237da1f4faf7fcd7e6309"}, - {file = "mypy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:faca7ab947c9f457a08dcb8d9a8664fd438080e002b0fa3e41b0535335edcf7f"}, - {file = "mypy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:5bc81701d52cc8767005fdd2a08c19980de9ec61a25dbd2a937dfb1338a826f9"}, - {file = "mypy-1.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8462655b6694feb1c99e433ea905d46c478041a8b8f0c33f1dab00ae881b2164"}, - {file = "mypy-1.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:923ea66d282d8af9e0f9c21ffc6653643abb95b658c3a8a32dca1eff09c06475"}, - {file = "mypy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1ebf9e796521f99d61864ed89d1fb2926d9ab6a5fab421e457cd9c7e4dd65aa9"}, - {file = "mypy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e478601cc3e3fa9d6734d255a59c7a2e5c2934da4378f3dd1e3411ea8a248642"}, - {file = "mypy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:c72861b7139a4f738344faa0e150834467521a3fba42dc98264e5aa9507dd601"}, - {file = "mypy-1.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52b9e1492e47e1790360a43755fa04101a7ac72287b1a53ce817f35899ba0521"}, - {file = "mypy-1.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:48d3e37dd7d9403e38fa86c46191de72705166d40b8c9f91a3de77350daa0893"}, - {file = "mypy-1.12.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f106db5ccb60681b622ac768455743ee0e6a857724d648c9629a9bd2ac3f721"}, - {file = "mypy-1.12.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:233e11b3f73ee1f10efada2e6da0f555b2f3a5316e9d8a4a1224acc10e7181d3"}, - {file = "mypy-1.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:4ae8959c21abcf9d73aa6c74a313c45c0b5a188752bf37dace564e29f06e9c1b"}, - {file = "mypy-1.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:eafc1b7319b40ddabdc3db8d7d48e76cfc65bbeeafaa525a4e0fa6b76175467f"}, - {file = "mypy-1.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9b9ce1ad8daeb049c0b55fdb753d7414260bad8952645367e70ac91aec90e07e"}, - {file = "mypy-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfe012b50e1491d439172c43ccb50db66d23fab714d500b57ed52526a1020bb7"}, - {file = "mypy-1.12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c40658d4fa1ab27cb53d9e2f1066345596af2f8fe4827defc398a09c7c9519b"}, - {file = "mypy-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:dee78a8b9746c30c1e617ccb1307b351ded57f0de0d287ca6276378d770006c0"}, - {file = "mypy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b5df6c8a8224f6b86746bda716bbe4dbe0ce89fd67b1fa4661e11bfe38e8ec8"}, - {file = "mypy-1.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5feee5c74eb9749e91b77f60b30771563327329e29218d95bedbe1257e2fe4b0"}, - {file = "mypy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:77278e8c6ffe2abfba6db4125de55f1024de9a323be13d20e4f73b8ed3402bd1"}, - {file = "mypy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:dcfb754dea911039ac12434d1950d69a2f05acd4d56f7935ed402be09fad145e"}, - {file = "mypy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:06de0498798527451ffb60f68db0d368bd2bae2bbfb5237eae616d4330cc87aa"}, - {file = "mypy-1.12.0-py3-none-any.whl", hash = 
"sha256:fd313226af375d52e1e36c383f39bf3836e1f192801116b31b090dfcd3ec5266"}, - {file = "mypy-1.12.0.tar.gz", hash = "sha256:65a22d87e757ccd95cbbf6f7e181e6caa87128255eb2b6be901bb71b26d8a99d"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = 
"mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, ] [package.dependencies] @@ -2555,6 +2555,7 @@ typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] install-types = ["pip"] mypyc = ["setuptools (>=50)"] reports = ["lxml"] @@ -4036,29 +4037,29 @@ files = [ [[package]] name = "ruff" -version = "0.6.9" +version = "0.7.0" description = "An extremely fast Python linter and code formatter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"}, - {file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"}, - {file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"}, - {file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"}, - {file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"}, - {file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"}, - {file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"}, + {file = "ruff-0.7.0-py3-none-linux_armv6l.whl", hash = "sha256:0cdf20c2b6ff98e37df47b2b0bd3a34aaa155f59a11182c1303cce79be715628"}, + {file = "ruff-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:496494d350c7fdeb36ca4ef1c9f21d80d182423718782222c29b3e72b3512737"}, + {file = "ruff-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:214b88498684e20b6b2b8852c01d50f0651f3cc6118dfa113b4def9f14faaf06"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630fce3fefe9844e91ea5bbf7ceadab4f9981f42b704fae011bb8efcaf5d84be"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:211d877674e9373d4bb0f1c80f97a0201c61bcd1e9d045b6e9726adc42c156aa"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:194d6c46c98c73949a106425ed40a576f52291c12bc21399eb8f13a0f7073495"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:82c2579b82b9973a110fab281860403b397c08c403de92de19568f32f7178598"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9af971fe85dcd5eaed8f585ddbc6bdbe8c217fb8fcf510ea6bca5bdfff56040e"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b641c7f16939b7d24b7bfc0be4102c56562a18281f84f635604e8a6989948914"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d71672336e46b34e0c90a790afeac8a31954fd42872c1f6adaea1dff76fd44f9"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ab7d98c7eed355166f367597e513a6c82408df4181a937628dbec79abb2a1fe4"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1eb54986f770f49edb14f71d33312d79e00e629a57387382200b1ef12d6a4ef9"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:dc452ba6f2bb9cf8726a84aa877061a2462afe9ae0ea1d411c53d226661c601d"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4b406c2dce5be9bad59f2de26139a86017a517e6bcd2688da515481c05a2cb11"}, + {file = "ruff-0.7.0-py3-none-win32.whl", hash = "sha256:f6c968509f767776f524a8430426539587d5ec5c662f6addb6aa25bc2e8195ec"}, + {file = "ruff-0.7.0-py3-none-win_amd64.whl", hash = "sha256:ff4aabfbaaba880e85d394603b9e75d32b0693152e16fa659a3064a85df7fce2"}, + {file = "ruff-0.7.0-py3-none-win_arm64.whl", hash = "sha256:10842f69c245e78d6adec7e1db0a7d9ddc2fff0621d730e61657b64fa36f207e"}, + {file = "ruff-0.7.0.tar.gz", hash = "sha256:47a86360cf62d9cd53ebfb0b5eb0e882193fc191c6d717e8bef4462bc3b9ea2b"}, ] [[package]] @@ -4377,13 +4378,13 @@ files = [ [[package]] name = "tox" -version = "4.22.0" +version = "4.23.2" description = "tox is a generic virtualenv management and test command line tool" optional = false python-versions = ">=3.8" files = [ - {file = "tox-4.22.0-py3-none-any.whl", hash = "sha256:03734d9a9ac138cd1a898a372fb1b8079e2728618ae06dc37cbf3686cfb56eea"}, - {file = "tox-4.22.0.tar.gz", hash = "sha256:acc6c627cb3316585238d55d2b633e132fea1bdb01b9d93b56bce7caea6ae73d"}, + {file = "tox-4.23.2-py3-none-any.whl", hash = "sha256:452bc32bb031f2282881a2118923176445bac783ab97c874b8770ab4c3b76c38"}, + {file = "tox-4.23.2.tar.gz", hash = "sha256:86075e00e555df6e82e74cfc333917f91ecb47ffbc868dcafbd2672e332f4a2c"}, ] [package.dependencies] @@ -4399,6 +4400,9 @@ tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} typing-extensions = {version = ">=4.12.2", markers = "python_version < \"3.11\""} virtualenv = ">=20.26.6" +[package.extras] +test = ["devpi-process (>=1.0.2)", "pytest (>=8.3.3)", "pytest-mock (>=3.14)"] + [[package]] name = "tqdm" version = "4.66.5" diff --git a/supervision/detection/tools/polygon_zone.py b/supervision/detection/tools/polygon_zone.py index b8b728fc5..f69f3c9fe 100644 --- a/supervision/detection/tools/polygon_zone.py +++ b/supervision/detection/tools/polygon_zone.py @@ -17,6 +17,12 @@ class PolygonZone: """ A class for defining a polygon-shaped zone within a frame for detecting objects. + !!! warning + + LineZone uses the `tracker_id`. Read + [here](/latest/trackers/) to learn how to plug + tracking into your inference pipeline. + Attributes: polygon (np.ndarray): A polygon represented by a numpy array of shape `(N, 2)`, containing the `x`, `y` coordinates of the points. 
@@ -26,6 +32,28 @@ class PolygonZone:
            (default: (sv.Position.BOTTOM_CENTER,)).
        current_count (int): The current count of detected objects within the zone
        mask (np.ndarray): The 2D bool mask for the polygon zone
+
+    Example:
+        ```python
+        import supervision as sv
+        from ultralytics import YOLO
+        import numpy as np
+        import cv2
+
+        image = cv2.imread("<SOURCE_IMAGE_PATH>")
+        model = YOLO("yolo11s")
+        tracker = sv.ByteTrack()
+
+        polygon = np.array([[100, 200], [200, 100], [300, 200], [200, 300]])
+        polygon_zone = sv.PolygonZone(polygon=polygon)
+
+        result = model(image)[0]
+        detections = sv.Detections.from_ultralytics(result)
+        detections = tracker.update_with_detections(detections)
+
+        is_detections_in_zone = polygon_zone.trigger(detections)
+        print(polygon_zone.current_count)
+        ```
    """

    def __init__(
diff --git a/supervision/geometry/utils.py b/supervision/geometry/utils.py
index 8a0ca35c5..2247adc50 100644
--- a/supervision/geometry/utils.py
+++ b/supervision/geometry/utils.py
@@ -16,6 +16,9 @@ def get_polygon_center(polygon: np.ndarray) -> Point:
        Point: The center of the polygon, represented as a Point object
            with x and y attributes.

+    Raises:
+        ValueError: If the polygon has no vertices.
+
    Examples:
        ```python
        import numpy as np
@@ -30,6 +33,9 @@ def get_polygon_center(polygon: np.ndarray) -> Point:
    # This is one of the 3 candidate algorithms considered for centroid calculation.
    # For a more detailed discussion, see PR #1084 and commit eb33176

+    if len(polygon) == 0:
+        raise ValueError("Polygon must have at least one vertex.")
+
    shift_polygon = np.roll(polygon, -1, axis=0)
    signed_areas = np.cross(polygon, shift_polygon) / 2
    if signed_areas.sum() == 0:
diff --git a/supervision/metrics/__init__.py b/supervision/metrics/__init__.py
index 8ae33e639..90fc17b47 100644
--- a/supervision/metrics/__init__.py
+++ b/supervision/metrics/__init__.py
@@ -8,6 +8,8 @@
    MeanAveragePrecision,
    MeanAveragePrecisionResult,
)
+from supervision.metrics.precision import Precision, PrecisionResult
+from supervision.metrics.recall import Recall, RecallResult
from supervision.metrics.utils.object_size import (
    ObjectSizeCategory,
    get_detection_size_category,
diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py
index d1818441e..def5999a0 100644
--- a/supervision/metrics/core.py
+++ b/supervision/metrics/core.py
@@ -37,9 +37,10 @@ class MetricTarget(Enum):
    """
    Specifies what type of detection is used to compute the metric.

-    * BOXES: xyxy bounding boxes
-    * MASKS: Binary masks
-    * ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB)
+    Attributes:
+        BOXES: xyxy bounding boxes
+        MASKS: Binary masks
+        ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB)
    """

    BOXES = "boxes"
@@ -54,15 +55,16 @@ class AveragingMethod(Enum):
    """
    Suppose, before returning the final result, a metric is computed for each class.
    How do you combine those to get the final number?

-    * MACRO: Calculate the metric for each class and average the results. The simplest
-        averaging method, but it does not take class imbalance into account.
-    * MICRO: Calculate the metric globally by counting the total true positives, false
-        positives, and false negatives. Micro averaging is useful when you want to give
-        more importance to classes with more samples. It's also more appropriate if you
-        have an imbalance in the number of instances per class.
-    * WEIGHTED: Calculate the metric for each class and average the results, weighted by
-        the number of true instances of each class. Use weighted averaging if you want
-        to take class imbalance into account.
+ Attributes: + MACRO: Calculate the metric for each class and average the results. The simplest + averaging method, but it does not take class imbalance into account. + MICRO: Calculate the metric globally by counting the total true positives, false + positives, and false negatives. Micro averaging is useful when you want to + give more importance to classes with more samples. It's also more + appropriate if you have an imbalance in the number of instances per class. + WEIGHTED: Calculate the metric for each class and average the results, weighted + by the number of true instances of each class. Use weighted averaging if + you want to take class imbalance into account. """ MACRO = "macro" diff --git a/supervision/metrics/f1_score.py b/supervision/metrics/f1_score.py index 2ca5bca5c..ba4fcd59a 100644 --- a/supervision/metrics/f1_score.py +++ b/supervision/metrics/f1_score.py @@ -23,11 +23,45 @@ class F1Score(Metric): + """ + F1 Score is a metric used to evaluate object detection models. It is the harmonic + mean of precision and recall, calculated at different IoU thresholds. + + In simple terms, F1 Score is a measure of a model's balance between precision and + recall (accuracy and completeness), calculated as: + + `F1 = 2 * (precision * recall) / (precision + recall)` + + Example: + ```python + import supervision as sv + from supervision.metrics import F1Score + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + f1_metric = F1Score() + f1_result = f1_metric.update(predictions, targets).compute() + + print(f1_result) + print(f1_result.f1_50) + print(f1_result.small_objects.f1_50) + ``` + """ + def __init__( self, metric_target: MetricTarget = MetricTarget.BOXES, averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, ): + """ + Initialize the F1Score metric. + + Args: + metric_target (MetricTarget): The type of detection data to use. + averaging_method (AveragingMethod): The averaging method used to compute the + F1 scores. Determines how the F1 scores are aggregated across classes. + """ self._metric_target = metric_target if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: raise NotImplementedError( @@ -40,6 +74,9 @@ def __init__( self._targets_list: List[Detections] = [] def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ self._predictions_list = [] self._targets_list = [] @@ -48,6 +85,16 @@ def update( predictions: Union[Detections, List[Detections]], targets: Union[Detections, List[Detections]], ) -> F1Score: + """ + Add new predictions and targets to the metric, but do not compute the result. + + Args: + predictions (Union[Detections, List[Detections]]): The predicted detections. + targets (Union[Detections, List[Detections]]): The target detections. + + Returns: + (F1Score): The updated metric instance. + """ if not isinstance(predictions, list): predictions = [predictions] if not isinstance(targets, list): @@ -65,6 +112,13 @@ def update( return self def compute(self) -> F1ScoreResult: + """ + Calculate the F1 score metric based on the stored predictions and ground-truth + data, at different IoU thresholds. + + Returns: + (F1ScoreResult): The F1 score metric result. + """ result = self._compute(self._predictions_list, self._targets_list) small_predictions, small_targets = self._filter_predictions_and_targets_by_size( @@ -373,7 +427,6 @@ class F1ScoreResult: The results of the F1 score metric calculation. Defaults to `0` if no detections or targets were provided. 
- Provides a custom `__str__` method for pretty printing. Attributes: metric_target (MetricTarget): the type of data used for the metric - diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index dbd60b2e7..8cec50c85 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -23,6 +23,27 @@ class MeanAveragePrecision(Metric): + """ + Mean Average Precision (mAP) is a metric used to evaluate object detection models. + It is the average of the precision-recall curves at different IoU thresholds. + + Example: + ```python + import supervision as sv + from supervision.metrics import MeanAveragePrecision + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + map_metric = MeanAveragePrecision() + map_result = map_metric.update(predictions, targets).compute() + + print(map_result) + print(map_result.map50_95) + map_result.plot() + ``` + """ + def __init__( self, metric_target: MetricTarget = MetricTarget.BOXES, @@ -47,6 +68,9 @@ def __init__( self._targets_list: List[Detections] = [] def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ self._predictions_list = [] self._targets_list = [] @@ -95,26 +119,10 @@ def compute( ) -> MeanAveragePrecisionResult: """ Calculate Mean Average Precision based on predicted and ground-truth - detections at different thresholds. + detections at different thresholds. Returns: - (MeanAveragePrecisionResult): New instance of MeanAveragePrecision. - - Example: - ```python - import supervision as sv - from supervision.metrics import MeanAveragePrecision - - predictions = sv.Detections(...) - targets = sv.Detections(...) - - map_metric = MeanAveragePrecision() - map_result = map_metric.update(predictions, targets).compute() - - print(map_result) - print(map_result.map50_95) - map_result.plot() - ``` + (MeanAveragePrecisionResult): The Mean Average Precision result. """ result = self._compute(self._predictions_list, self._targets_list) diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py new file mode 100644 index 000000000..d915e1f49 --- /dev/null +++ b/supervision/metrics/precision.py @@ -0,0 +1,603 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import numpy as np +from matplotlib import pyplot as plt + +from supervision.config import ORIENTED_BOX_COORDINATES +from supervision.detection.core import Detections +from supervision.detection.utils import box_iou_batch, mask_iou_batch +from supervision.draw.color import LEGACY_COLOR_PALETTE +from supervision.metrics.core import AveragingMethod, Metric, MetricTarget +from supervision.metrics.utils.object_size import ( + ObjectSizeCategory, + get_detection_size_category, +) +from supervision.metrics.utils.utils import ensure_pandas_installed + +if TYPE_CHECKING: + import pandas as pd + + +class Precision(Metric): + """ + Precision is a metric used to evaluate object detection models. It is the ratio of + true positive detections to the total number of predicted detections. We calculate + it at different IoU thresholds. + + In simple terms, Precision is a measure of a model's accuracy, calculated as: + + `Precision = TP / (TP + FP)` + + Here, `TP` is the number of true positives (correct detections), and `FP` is the + number of false positive detections (detected, but incorrectly). 
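To make the formula concrete before the usage example: with illustrative counts that are not library output, say 8 true positives, 2 false positives and 4 false negatives, precision is 8 / 10 = 0.8. The companion recall formula, defined in `recall.py` further down, is included for contrast:

```python
# Illustrative counts only: 8 correct detections, 2 spurious, 4 missed.
tp, fp, fn = 8, 2, 4

precision = tp / (tp + fp)  # 8 / 10 = 0.80
recall = tp / (tp + fn)     # 8 / 12 = 0.67 (rounded)
print(f"precision={precision:.2f}, recall={recall:.2f}")
```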
+ + Example: + ```python + import supervision as sv + from supervision.metrics import Precision + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + precision_metric = Precision() + precision_result = precision_metric.update(predictions, targets).compute() + + print(precision_result) + print(precision_result.precision_at_50) + print(precision_result.small_objects.precision_at_50) + ``` + """ + + def __init__( + self, + metric_target: MetricTarget = MetricTarget.BOXES, + averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, + ): + """ + Initialize the Precision metric. + + Args: + metric_target (MetricTarget): The type of detection data to use. + averaging_method (AveragingMethod): The averaging method used to compute the + precision. Determines how the precision is aggregated across classes. + """ + self._metric_target = metric_target + if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + raise NotImplementedError( + "Precision is not implemented for oriented bounding boxes." + ) + + self._metric_target = metric_target + self.averaging_method = averaging_method + self._predictions_list: List[Detections] = [] + self._targets_list: List[Detections] = [] + + def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ + self._predictions_list = [] + self._targets_list = [] + + def update( + self, + predictions: Union[Detections, List[Detections]], + targets: Union[Detections, List[Detections]], + ) -> Precision: + """ + Add new predictions and targets to the metric, but do not compute the result. + + Args: + predictions (Union[Detections, List[Detections]]): The predicted detections. + targets (Union[Detections, List[Detections]]): The target detections. + + Returns: + (Precision): The updated metric instance. + """ + if not isinstance(predictions, list): + predictions = [predictions] + if not isinstance(targets, list): + targets = [targets] + + if len(predictions) != len(targets): + raise ValueError( + f"The number of predictions ({len(predictions)}) and" + f" targets ({len(targets)}) during the update must be the same." + ) + + self._predictions_list.extend(predictions) + self._targets_list.extend(targets) + + return self + + def compute(self) -> PrecisionResult: + """ + Calculate the precision metric based on the stored predictions and ground-truth + data, at different IoU thresholds. + + Returns: + (PrecisionResult): The precision metric result. 
+ """ + result = self._compute(self._predictions_list, self._targets_list) + + small_predictions, small_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL + ) + result.small_objects = self._compute(small_predictions, small_targets) + + medium_predictions, medium_targets = ( + self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM + ) + ) + result.medium_objects = self._compute(medium_predictions, medium_targets) + + large_predictions, large_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE + ) + result.large_objects = self._compute(large_predictions, large_targets) + + return result + + def _compute( + self, predictions_list: List[Detections], targets_list: List[Detections] + ) -> PrecisionResult: + iou_thresholds = np.linspace(0.5, 0.95, 10) + stats = [] + + for predictions, targets in zip(predictions_list, targets_list): + prediction_contents = self._detections_content(predictions) + target_contents = self._detections_content(targets) + + if len(targets) > 0: + if len(predictions) == 0: + stats.append( + ( + np.zeros((0, iou_thresholds.size), dtype=bool), + np.zeros((0,), dtype=np.float32), + np.zeros((0,), dtype=int), + targets.class_id, + ) + ) + + else: + if self._metric_target == MetricTarget.BOXES: + iou = box_iou_batch(target_contents, prediction_contents) + elif self._metric_target == MetricTarget.MASKS: + iou = mask_iou_batch(target_contents, prediction_contents) + else: + raise NotImplementedError( + "Unsupported metric target for IoU calculation" + ) + + matches = self._match_detection_batch( + predictions.class_id, targets.class_id, iou, iou_thresholds + ) + stats.append( + ( + matches, + predictions.confidence, + predictions.class_id, + targets.class_id, + ) + ) + + if not stats: + return PrecisionResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + precision_scores=np.zeros(iou_thresholds.shape[0]), + precision_per_class=np.zeros((0, iou_thresholds.shape[0])), + iou_thresholds=iou_thresholds, + matched_classes=np.array([], dtype=int), + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)] + precision_scores, precision_per_class, unique_classes = ( + self._compute_precision_for_classes(*concatenated_stats) + ) + + return PrecisionResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + precision_scores=precision_scores, + precision_per_class=precision_per_class, + iou_thresholds=iou_thresholds, + matched_classes=unique_classes, + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + def _compute_precision_for_classes( + self, + matches: np.ndarray, + prediction_confidence: np.ndarray, + prediction_class_ids: np.ndarray, + true_class_ids: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sorted_indices = np.argsort(-prediction_confidence) + matches = matches[sorted_indices] + prediction_class_ids = prediction_class_ids[sorted_indices] + unique_classes, class_counts = np.unique(true_class_ids, return_counts=True) + + # Shape: PxTh,P,C,C -> CxThx3 + confusion_matrix = self._compute_confusion_matrix( + matches, prediction_class_ids, unique_classes, class_counts + ) + + # Shape: CxThx3 -> CxTh + precision_per_class = self._compute_precision(confusion_matrix) + + # Shape: CxTh -> 
Th + if self.averaging_method == AveragingMethod.MACRO: + precision_scores = np.mean(precision_per_class, axis=0) + elif self.averaging_method == AveragingMethod.MICRO: + confusion_matrix_merged = confusion_matrix.sum(0) + precision_scores = self._compute_precision(confusion_matrix_merged) + elif self.averaging_method == AveragingMethod.WEIGHTED: + class_counts = class_counts.astype(np.float32) + precision_scores = np.average( + precision_per_class, axis=0, weights=class_counts + ) + + return precision_scores, precision_per_class, unique_classes + + @staticmethod + def _match_detection_batch( + predictions_classes: np.ndarray, + target_classes: np.ndarray, + iou: np.ndarray, + iou_thresholds: np.ndarray, + ) -> np.ndarray: + num_predictions, num_iou_levels = ( + predictions_classes.shape[0], + iou_thresholds.shape[0], + ) + correct = np.zeros((num_predictions, num_iou_levels), dtype=bool) + correct_class = target_classes[:, None] == predictions_classes + + for i, iou_level in enumerate(iou_thresholds): + matched_indices = np.where((iou >= iou_level) & correct_class) + + if matched_indices[0].shape[0]: + combined_indices = np.stack(matched_indices, axis=1) + iou_values = iou[matched_indices][:, None] + matches = np.hstack([combined_indices, iou_values]) + + if matched_indices[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + + correct[matches[:, 1].astype(int), i] = True + + return correct + + @staticmethod + def _compute_confusion_matrix( + sorted_matches: np.ndarray, + sorted_prediction_class_ids: np.ndarray, + unique_classes: np.ndarray, + class_counts: np.ndarray, + ) -> np.ndarray: + """ + Compute the confusion matrix for each class and IoU threshold. + + Assumes the matches and prediction_class_ids are sorted by confidence + in descending order. + + Arguments: + sorted_matches: np.ndarray, bool, shape (P, Th), that is True + if the prediction is a true positive at the given IoU threshold. + sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing + the class id for each prediction. + unique_classes: np.ndarray, int, shape (C,), containing the unique + class ids. + class_counts: np.ndarray, int, shape (C,), containing the number + of true instances for each class. + + Returns: + np.ndarray, shape (C, Th, 3), containing the true positives, false + positives, and false negatives for each class and IoU threshold. 
+        """
+
+        num_thresholds = sorted_matches.shape[1]
+        num_classes = unique_classes.shape[0]
+
+        confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
+        for class_idx, class_id in enumerate(unique_classes):
+            is_class = sorted_prediction_class_ids == class_id
+            num_true = class_counts[class_idx]
+            num_predictions = is_class.sum()
+
+            if num_predictions == 0:
+                true_positives = np.zeros(num_thresholds)
+                false_positives = np.zeros(num_thresholds)
+                false_negatives = np.full(num_thresholds, num_true)
+            elif num_true == 0:
+                true_positives = np.zeros(num_thresholds)
+                false_positives = np.full(num_thresholds, num_predictions)
+                false_negatives = np.zeros(num_thresholds)
+            else:
+                true_positives = sorted_matches[is_class].sum(0)
+                false_positives = (1 - sorted_matches[is_class]).sum(0)
+                false_negatives = num_true - true_positives
+            confusion_matrix[class_idx] = np.stack(
+                [true_positives, false_positives, false_negatives], axis=1
+            )
+
+        return confusion_matrix
+
+    @staticmethod
+    def _compute_precision(confusion_matrix: np.ndarray) -> np.ndarray:
+        """
+        Broadcastable function, computing the precision from the confusion matrix.
+
+        Arguments:
+            confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
+                contains the true positives, false positives, and false negatives.
+
+        Returns:
+            np.ndarray, shape (N, ...), containing the precision for each element.
+        """
+        if not confusion_matrix.shape[-1] == 3:
+            raise ValueError(
+                f"Confusion matrix must have shape (..., 3), got "
+                f"{confusion_matrix.shape}"
+            )
+        true_positives = confusion_matrix[..., 0]
+        false_positives = confusion_matrix[..., 1]
+
+        denominator = true_positives + false_positives
+        precision = np.where(denominator == 0, 0, true_positives / denominator)
+
+        return precision
+
+    def _detections_content(self, detections: Detections) -> np.ndarray:
+        """Return boxes, masks or oriented bounding boxes from detections."""
+        if self._metric_target == MetricTarget.BOXES:
+            return detections.xyxy
+        if self._metric_target == MetricTarget.MASKS:
+            return (
+                detections.mask
+                if detections.mask is not None
+                else np.empty((0, 0, 0), dtype=bool)
+            )
+        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+            if obb := detections.data.get(ORIENTED_BOX_COORDINATES):
+                return np.array(obb, dtype=np.float32)
+            return np.empty((0, 8), dtype=np.float32)
+        raise ValueError(f"Invalid metric target: {self._metric_target}")
+
+    def _filter_detections_by_size(
+        self, detections: Detections, size_category: ObjectSizeCategory
+    ) -> Detections:
+        """Return a copy of detections with contents filtered by object size."""
+        new_detections = deepcopy(detections)
+        if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
+            return new_detections
+
+        sizes = get_detection_size_category(new_detections, self._metric_target)
+        size_mask = sizes == size_category.value
+
+        new_detections.xyxy = new_detections.xyxy[size_mask]
+        if new_detections.mask is not None:
+            new_detections.mask = new_detections.mask[size_mask]
+        if new_detections.class_id is not None:
+            new_detections.class_id = new_detections.class_id[size_mask]
+        if new_detections.confidence is not None:
+            new_detections.confidence = new_detections.confidence[size_mask]
+        if new_detections.tracker_id is not None:
+            new_detections.tracker_id = new_detections.tracker_id[size_mask]
+        if new_detections.data is not None:
+            for key, value in new_detections.data.items():
+                new_detections.data[key] = np.array(value)[size_mask]
+
+        return new_detections
+
+    def
_filter_predictions_and_targets_by_size( + self, + predictions_list: List[Detections], + targets_list: List[Detections], + size_category: ObjectSizeCategory, + ) -> Tuple[List[Detections], List[Detections]]: + """ + Filter predictions and targets by object size category. + """ + new_predictions_list = [] + new_targets_list = [] + for predictions, targets in zip(predictions_list, targets_list): + new_predictions_list.append( + self._filter_detections_by_size(predictions, size_category) + ) + new_targets_list.append( + self._filter_detections_by_size(targets, size_category) + ) + return new_predictions_list, new_targets_list + + +@dataclass +class PrecisionResult: + """ + The results of the precision metric calculation. + + Defaults to `0` if no detections or targets were provided. + + Attributes: + metric_target (MetricTarget): the type of data used for the metric - + boxes, masks or oriented bounding boxes. + averaging_method (AveragingMethod): the averaging method used to compute the + precision. Determines how the precision is aggregated across classes. + precision_at_50 (float): the precision at IoU threshold of `0.5`. + precision_at_75 (float): the precision at IoU threshold of `0.75`. + precision_scores (np.ndarray): the precision scores at each IoU threshold. + Shape: `(num_iou_thresholds,)` + precision_per_class (np.ndarray): the precision scores per class and + IoU threshold. Shape: `(num_target_classes, num_iou_thresholds)` + iou_thresholds (np.ndarray): the IoU thresholds used in the calculations. + matched_classes (np.ndarray): the class IDs of all matched classes. + Corresponds to the rows of `precision_per_class`. + small_objects (Optional[PrecisionResult]): the Precision metric results + for small objects. + medium_objects (Optional[PrecisionResult]): the Precision metric results + for medium objects. + large_objects (Optional[PrecisionResult]): the Precision metric results + for large objects. + """ + + metric_target: MetricTarget + averaging_method: AveragingMethod + + @property + def precision_at_50(self) -> float: + return self.precision_scores[0] + + @property + def precision_at_75(self) -> float: + return self.precision_scores[5] + + precision_scores: np.ndarray + precision_per_class: np.ndarray + iou_thresholds: np.ndarray + matched_classes: np.ndarray + + small_objects: Optional[PrecisionResult] + medium_objects: Optional[PrecisionResult] + large_objects: Optional[PrecisionResult] + + def __str__(self) -> str: + """ + Format as a pretty string. 
+ + Example: + ```python + print(precision_result) + ``` + """ + out_str = ( + f"{self.__class__.__name__}:\n" + f"Metric target: {self.metric_target}\n" + f"Averaging method: {self.averaging_method}\n" + f"P @ 50: {self.precision_at_50:.4f}\n" + f"P @ 75: {self.precision_at_75:.4f}\n" + f"P @ thresh: {self.precision_scores}\n" + f"IoU thresh: {self.iou_thresholds}\n" + f"Precision per class:\n" + ) + if self.precision_per_class.size == 0: + out_str += " No results\n" + for class_id, precision_of_class in zip( + self.matched_classes, self.precision_per_class + ): + out_str += f" {class_id}: {precision_of_class}\n" + + indent = " " + if self.small_objects is not None: + indented = indent + str(self.small_objects).replace("\n", f"\n{indent}") + out_str += f"\nSmall objects:\n{indented}" + if self.medium_objects is not None: + indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}") + out_str += f"\nMedium objects:\n{indented}" + if self.large_objects is not None: + indented = indent + str(self.large_objects).replace("\n", f"\n{indent}") + out_str += f"\nLarge objects:\n{indented}" + + return out_str + + def to_pandas(self) -> "pd.DataFrame": + """ + Convert the result to a pandas DataFrame. + + Returns: + (pd.DataFrame): The result as a DataFrame. + """ + ensure_pandas_installed() + import pandas as pd + + pandas_data = { + "P@50": self.precision_at_50, + "P@75": self.precision_at_75, + } + + if self.small_objects is not None: + small_objects_df = self.small_objects.to_pandas() + for key, value in small_objects_df.items(): + pandas_data[f"small_objects_{key}"] = value + if self.medium_objects is not None: + medium_objects_df = self.medium_objects.to_pandas() + for key, value in medium_objects_df.items(): + pandas_data[f"medium_objects_{key}"] = value + if self.large_objects is not None: + large_objects_df = self.large_objects.to_pandas() + for key, value in large_objects_df.items(): + pandas_data[f"large_objects_{key}"] = value + + return pd.DataFrame(pandas_data, index=[0]) + + def plot(self): + """ + Plot the precision results. 
+ """ + + labels = ["Precision@50", "Precision@75"] + values = [self.precision_at_50, self.precision_at_75] + colors = [LEGACY_COLOR_PALETTE[0]] * 2 + + if self.small_objects is not None: + small_objects = self.small_objects + labels += ["Small: P@50", "Small: P@75"] + values += [small_objects.precision_at_50, small_objects.precision_at_75] + colors += [LEGACY_COLOR_PALETTE[3]] * 2 + + if self.medium_objects is not None: + medium_objects = self.medium_objects + labels += ["Medium: P@50", "Medium: P@75"] + values += [medium_objects.precision_at_50, medium_objects.precision_at_75] + colors += [LEGACY_COLOR_PALETTE[2]] * 2 + + if self.large_objects is not None: + large_objects = self.large_objects + labels += ["Large: P@50", "Large: P@75"] + values += [large_objects.precision_at_50, large_objects.precision_at_75] + colors += [LEGACY_COLOR_PALETTE[4]] * 2 + + plt.rcParams["font.family"] = "monospace" + + _, ax = plt.subplots(figsize=(10, 6)) + ax.set_ylim(0, 1) + ax.set_ylabel("Value", fontweight="bold") + title = ( + f"Precision, by Object Size" + f"\n(target: {self.metric_target.value}," + f" averaging: {self.averaging_method.value})" + ) + ax.set_title(title, fontweight="bold") + + x_positions = range(len(labels)) + bars = ax.bar(x_positions, values, color=colors, align="center") + + ax.set_xticks(x_positions) + ax.set_xticklabels(labels, rotation=45, ha="right") + + for bar in bars: + y_value = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2, + y_value + 0.02, + f"{y_value:.2f}", + ha="center", + va="bottom", + ) + + plt.rcParams["font.family"] = "sans-serif" + + plt.tight_layout() + plt.show() diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py new file mode 100644 index 000000000..9eae24f8e --- /dev/null +++ b/supervision/metrics/recall.py @@ -0,0 +1,601 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import numpy as np +from matplotlib import pyplot as plt + +from supervision.config import ORIENTED_BOX_COORDINATES +from supervision.detection.core import Detections +from supervision.detection.utils import box_iou_batch, mask_iou_batch +from supervision.draw.color import LEGACY_COLOR_PALETTE +from supervision.metrics.core import AveragingMethod, Metric, MetricTarget +from supervision.metrics.utils.object_size import ( + ObjectSizeCategory, + get_detection_size_category, +) +from supervision.metrics.utils.utils import ensure_pandas_installed + +if TYPE_CHECKING: + import pandas as pd + + +class Recall(Metric): + """ + Recall is a metric used to evaluate object detection models. It is the ratio of + true positive detections to the total number of ground truth instances. We calculate + it at different IoU thresholds. + + In simple terms, Recall is a measure of a model's completeness, calculated as: + + `Recall = TP / (TP + FN)` + + Here, `TP` is the number of true positives (correct detections), and `FN` is the + number of false negatives (missed detections). + + Example: + ```python + import supervision as sv + from supervision.metrics import Recall + + predictions = sv.Detections(...) + targets = sv.Detections(...) 
+
+        recall_metric = Recall()
+        recall_result = recall_metric.update(predictions, targets).compute()
+
+        print(recall_result)
+        print(recall_result.recall_at_50)
+        print(recall_result.small_objects.recall_at_50)
+        ```
+    """
+
+    def __init__(
+        self,
+        metric_target: MetricTarget = MetricTarget.BOXES,
+        averaging_method: AveragingMethod = AveragingMethod.WEIGHTED,
+    ):
+        """
+        Initialize the Recall metric.
+
+        Args:
+            metric_target (MetricTarget): The type of detection data to use.
+            averaging_method (AveragingMethod): The averaging method used to compute the
+                recall. Determines how the recall is aggregated across classes.
+        """
+        self._metric_target = metric_target
+        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+            raise NotImplementedError(
+                "Recall is not implemented for oriented bounding boxes."
+            )
+
+        self._metric_target = metric_target
+        self.averaging_method = averaging_method
+        self._predictions_list: List[Detections] = []
+        self._targets_list: List[Detections] = []
+
+    def reset(self) -> None:
+        """
+        Reset the metric to its initial state, clearing all stored data.
+        """
+        self._predictions_list = []
+        self._targets_list = []
+
+    def update(
+        self,
+        predictions: Union[Detections, List[Detections]],
+        targets: Union[Detections, List[Detections]],
+    ) -> Recall:
+        """
+        Add new predictions and targets to the metric, but do not compute the result.
+
+        Args:
+            predictions (Union[Detections, List[Detections]]): The predicted detections.
+            targets (Union[Detections, List[Detections]]): The target detections.
+
+        Returns:
+            (Recall): The updated metric instance.
+        """
+        if not isinstance(predictions, list):
+            predictions = [predictions]
+        if not isinstance(targets, list):
+            targets = [targets]
+
+        if len(predictions) != len(targets):
+            raise ValueError(
+                f"The number of predictions ({len(predictions)}) and"
+                f" targets ({len(targets)}) during the update must be the same."
+            )
+
+        self._predictions_list.extend(predictions)
+        self._targets_list.extend(targets)
+
+        return self
+
+    def compute(self) -> RecallResult:
+        """
+        Calculate the recall metric based on the stored predictions and ground-truth
+        data, at different IoU thresholds.
+
+        Returns:
+            (RecallResult): The recall metric result.
+ """ + result = self._compute(self._predictions_list, self._targets_list) + + small_predictions, small_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL + ) + result.small_objects = self._compute(small_predictions, small_targets) + + medium_predictions, medium_targets = ( + self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM + ) + ) + result.medium_objects = self._compute(medium_predictions, medium_targets) + + large_predictions, large_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE + ) + result.large_objects = self._compute(large_predictions, large_targets) + + return result + + def _compute( + self, predictions_list: List[Detections], targets_list: List[Detections] + ) -> RecallResult: + iou_thresholds = np.linspace(0.5, 0.95, 10) + stats = [] + + for predictions, targets in zip(predictions_list, targets_list): + prediction_contents = self._detections_content(predictions) + target_contents = self._detections_content(targets) + + if len(targets) > 0: + if len(predictions) == 0: + stats.append( + ( + np.zeros((0, iou_thresholds.size), dtype=bool), + np.zeros((0,), dtype=np.float32), + np.zeros((0,), dtype=int), + targets.class_id, + ) + ) + + else: + if self._metric_target == MetricTarget.BOXES: + iou = box_iou_batch(target_contents, prediction_contents) + elif self._metric_target == MetricTarget.MASKS: + iou = mask_iou_batch(target_contents, prediction_contents) + else: + raise NotImplementedError( + "Unsupported metric target for IoU calculation" + ) + + matches = self._match_detection_batch( + predictions.class_id, targets.class_id, iou, iou_thresholds + ) + stats.append( + ( + matches, + predictions.confidence, + predictions.class_id, + targets.class_id, + ) + ) + + if not stats: + return RecallResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + recall_scores=np.zeros(iou_thresholds.shape[0]), + recall_per_class=np.zeros((0, iou_thresholds.shape[0])), + iou_thresholds=iou_thresholds, + matched_classes=np.array([], dtype=int), + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)] + recall_scores, recall_per_class, unique_classes = ( + self._compute_recall_for_classes(*concatenated_stats) + ) + + return RecallResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + recall_scores=recall_scores, + recall_per_class=recall_per_class, + iou_thresholds=iou_thresholds, + matched_classes=unique_classes, + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + def _compute_recall_for_classes( + self, + matches: np.ndarray, + prediction_confidence: np.ndarray, + prediction_class_ids: np.ndarray, + true_class_ids: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sorted_indices = np.argsort(-prediction_confidence) + matches = matches[sorted_indices] + prediction_class_ids = prediction_class_ids[sorted_indices] + unique_classes, class_counts = np.unique(true_class_ids, return_counts=True) + + # Shape: PxTh,P,C,C -> CxThx3 + confusion_matrix = self._compute_confusion_matrix( + matches, prediction_class_ids, unique_classes, class_counts + ) + + # Shape: CxThx3 -> CxTh + recall_per_class = self._compute_recall(confusion_matrix) + + # Shape: CxTh -> Th + if self.averaging_method == 
AveragingMethod.MACRO: + recall_scores = np.mean(recall_per_class, axis=0) + elif self.averaging_method == AveragingMethod.MICRO: + confusion_matrix_merged = confusion_matrix.sum(0) + recall_scores = self._compute_recall(confusion_matrix_merged) + elif self.averaging_method == AveragingMethod.WEIGHTED: + class_counts = class_counts.astype(np.float32) + recall_scores = np.average(recall_per_class, axis=0, weights=class_counts) + + return recall_scores, recall_per_class, unique_classes + + @staticmethod + def _match_detection_batch( + predictions_classes: np.ndarray, + target_classes: np.ndarray, + iou: np.ndarray, + iou_thresholds: np.ndarray, + ) -> np.ndarray: + num_predictions, num_iou_levels = ( + predictions_classes.shape[0], + iou_thresholds.shape[0], + ) + correct = np.zeros((num_predictions, num_iou_levels), dtype=bool) + correct_class = target_classes[:, None] == predictions_classes + + for i, iou_level in enumerate(iou_thresholds): + matched_indices = np.where((iou >= iou_level) & correct_class) + + if matched_indices[0].shape[0]: + combined_indices = np.stack(matched_indices, axis=1) + iou_values = iou[matched_indices][:, None] + matches = np.hstack([combined_indices, iou_values]) + + if matched_indices[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + + correct[matches[:, 1].astype(int), i] = True + + return correct + + @staticmethod + def _compute_confusion_matrix( + sorted_matches: np.ndarray, + sorted_prediction_class_ids: np.ndarray, + unique_classes: np.ndarray, + class_counts: np.ndarray, + ) -> np.ndarray: + """ + Compute the confusion matrix for each class and IoU threshold. + + Assumes the matches and prediction_class_ids are sorted by confidence + in descending order. + + Arguments: + sorted_matches: np.ndarray, bool, shape (P, Th), that is True + if the prediction is a true positive at the given IoU threshold. + sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing + the class id for each prediction. + unique_classes: np.ndarray, int, shape (C,), containing the unique + class ids. + class_counts: np.ndarray, int, shape (C,), containing the number + of true instances for each class. + + Returns: + np.ndarray, shape (C, Th, 3), containing the true positives, false + positives, and false negatives for each class and IoU threshold. 
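+
+        Example:
+            Illustrative shapes only: with the 10 default IoU thresholds and
+            two unique target classes, the output has shape `(2, 10, 3)`, and
+            `confusion_matrix[0, 0]` is `[tp, fp, fn]` for the first class at
+            IoU `0.5`.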
+        """
+
+        num_thresholds = sorted_matches.shape[1]
+        num_classes = unique_classes.shape[0]
+
+        confusion_matrix = np.zeros((num_classes, num_thresholds, 3))
+        for class_idx, class_id in enumerate(unique_classes):
+            is_class = sorted_prediction_class_ids == class_id
+            num_true = class_counts[class_idx]
+            num_predictions = is_class.sum()
+
+            if num_predictions == 0:
+                true_positives = np.zeros(num_thresholds)
+                false_positives = np.zeros(num_thresholds)
+                false_negatives = np.full(num_thresholds, num_true)
+            elif num_true == 0:
+                true_positives = np.zeros(num_thresholds)
+                false_positives = np.full(num_thresholds, num_predictions)
+                false_negatives = np.zeros(num_thresholds)
+            else:
+                true_positives = sorted_matches[is_class].sum(0)
+                false_positives = (1 - sorted_matches[is_class]).sum(0)
+                false_negatives = num_true - true_positives
+            confusion_matrix[class_idx] = np.stack(
+                [true_positives, false_positives, false_negatives], axis=1
+            )
+
+        return confusion_matrix
+
+    @staticmethod
+    def _compute_recall(confusion_matrix: np.ndarray) -> np.ndarray:
+        """
+        Broadcastable function, computing the recall from the confusion matrix.
+
+        Arguments:
+            confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension
+                contains the true positives, false positives, and false negatives.
+
+        Returns:
+            np.ndarray, shape (N, ...), containing the recall for each element.
+        """
+        if confusion_matrix.shape[-1] != 3:
+            raise ValueError(
+                "Confusion matrix must have shape (..., 3), got "
+                f"{confusion_matrix.shape}"
+            )
+        true_positives = confusion_matrix[..., 0]
+        false_negatives = confusion_matrix[..., 2]
+
+        denominator = true_positives + false_negatives
+        recall = np.where(denominator == 0, 0, true_positives / denominator)
+
+        return recall
+
+    def _detections_content(self, detections: Detections) -> np.ndarray:
+        """Return boxes, masks or oriented bounding boxes from detections."""
+        if self._metric_target == MetricTarget.BOXES:
+            return detections.xyxy
+        if self._metric_target == MetricTarget.MASKS:
+            return (
+                detections.mask
+                if detections.mask is not None
+                else np.empty((0, 0, 0), dtype=bool)
+            )
+        if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES:
+            obb = detections.data.get(ORIENTED_BOX_COORDINATES)
+            if obb is not None and len(obb) > 0:
+                return np.array(obb, dtype=np.float32)
+            return np.empty((0, 8), dtype=np.float32)
+        raise ValueError(f"Invalid metric target: {self._metric_target}")
+
+    def _filter_detections_by_size(
+        self, detections: Detections, size_category: ObjectSizeCategory
+    ) -> Detections:
+        """Return a copy of detections with contents filtered by object size."""
+        new_detections = deepcopy(detections)
+        if detections.is_empty() or size_category == ObjectSizeCategory.ANY:
+            return new_detections
+
+        sizes = get_detection_size_category(new_detections, self._metric_target)
+        size_mask = sizes == size_category.value
+
+        new_detections.xyxy = new_detections.xyxy[size_mask]
+        if new_detections.mask is not None:
+            new_detections.mask = new_detections.mask[size_mask]
+        if new_detections.class_id is not None:
+            new_detections.class_id = new_detections.class_id[size_mask]
+        if new_detections.confidence is not None:
+            new_detections.confidence = new_detections.confidence[size_mask]
+        if new_detections.tracker_id is not None:
+            new_detections.tracker_id = new_detections.tracker_id[size_mask]
+        if new_detections.data is not None:
+            for key, value in new_detections.data.items():
+                new_detections.data[key] = np.array(value)[size_mask]
+
+        return new_detections
+
+    def 
_filter_predictions_and_targets_by_size( + self, + predictions_list: List[Detections], + targets_list: List[Detections], + size_category: ObjectSizeCategory, + ) -> Tuple[List[Detections], List[Detections]]: + """ + Filter predictions and targets by object size category. + """ + new_predictions_list = [] + new_targets_list = [] + for predictions, targets in zip(predictions_list, targets_list): + new_predictions_list.append( + self._filter_detections_by_size(predictions, size_category) + ) + new_targets_list.append( + self._filter_detections_by_size(targets, size_category) + ) + return new_predictions_list, new_targets_list + + +@dataclass +class RecallResult: + """ + The results of the recall metric calculation. + + Defaults to `0` if no detections or targets were provided. + + Attributes: + metric_target (MetricTarget): the type of data used for the metric - + boxes, masks or oriented bounding boxes. + averaging_method (AveragingMethod): the averaging method used to compute the + recall. Determines how the recall is aggregated across classes. + recall_at_50 (float): the recall at IoU threshold of `0.5`. + recall_at_75 (float): the recall at IoU threshold of `0.75`. + recall_scores (np.ndarray): the recall scores at each IoU threshold. + Shape: `(num_iou_thresholds,)` + recall_per_class (np.ndarray): the recall scores per class and IoU threshold. + Shape: `(num_target_classes, num_iou_thresholds)` + iou_thresholds (np.ndarray): the IoU thresholds used in the calculations. + matched_classes (np.ndarray): the class IDs of all matched classes. + Corresponds to the rows of `recall_per_class`. + small_objects (Optional[RecallResult]): the Recall metric results + for small objects. + medium_objects (Optional[RecallResult]): the Recall metric results + for medium objects. + large_objects (Optional[RecallResult]): the Recall metric results + for large objects. + """ + + metric_target: MetricTarget + averaging_method: AveragingMethod + + @property + def recall_at_50(self) -> float: + return self.recall_scores[0] + + @property + def recall_at_75(self) -> float: + return self.recall_scores[5] + + recall_scores: np.ndarray + recall_per_class: np.ndarray + iou_thresholds: np.ndarray + matched_classes: np.ndarray + + small_objects: Optional[RecallResult] + medium_objects: Optional[RecallResult] + large_objects: Optional[RecallResult] + + def __str__(self) -> str: + """ + Format as a pretty string. 
+ + Example: + ```python + print(recall_result) + ``` + """ + out_str = ( + f"{self.__class__.__name__}:\n" + f"Metric target: {self.metric_target}\n" + f"Averaging method: {self.averaging_method}\n" + f"R @ 50: {self.recall_at_50:.4f}\n" + f"R @ 75: {self.recall_at_75:.4f}\n" + f"R @ thresh: {self.recall_scores}\n" + f"IoU thresh: {self.iou_thresholds}\n" + f"Recall per class:\n" + ) + if self.recall_per_class.size == 0: + out_str += " No results\n" + for class_id, recall_of_class in zip( + self.matched_classes, self.recall_per_class + ): + out_str += f" {class_id}: {recall_of_class}\n" + + indent = " " + if self.small_objects is not None: + indented = indent + str(self.small_objects).replace("\n", f"\n{indent}") + out_str += f"\nSmall objects:\n{indented}" + if self.medium_objects is not None: + indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}") + out_str += f"\nMedium objects:\n{indented}" + if self.large_objects is not None: + indented = indent + str(self.large_objects).replace("\n", f"\n{indent}") + out_str += f"\nLarge objects:\n{indented}" + + return out_str + + def to_pandas(self) -> "pd.DataFrame": + """ + Convert the result to a pandas DataFrame. + + Returns: + (pd.DataFrame): The result as a DataFrame. + """ + ensure_pandas_installed() + import pandas as pd + + pandas_data = { + "R@50": self.recall_at_50, + "R@75": self.recall_at_75, + } + + if self.small_objects is not None: + small_objects_df = self.small_objects.to_pandas() + for key, value in small_objects_df.items(): + pandas_data[f"small_objects_{key}"] = value + if self.medium_objects is not None: + medium_objects_df = self.medium_objects.to_pandas() + for key, value in medium_objects_df.items(): + pandas_data[f"medium_objects_{key}"] = value + if self.large_objects is not None: + large_objects_df = self.large_objects.to_pandas() + for key, value in large_objects_df.items(): + pandas_data[f"large_objects_{key}"] = value + + return pd.DataFrame(pandas_data, index=[0]) + + def plot(self): + """ + Plot the recall results. 
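+
+        Example:
+            ```python
+            # Assumes `recall_result` is a RecallResult returned by
+            # `Recall.compute()`.
+            recall_result.plot()
+            ```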
+ """ + + labels = ["Recall@50", "Recall@75"] + values = [self.recall_at_50, self.recall_at_75] + colors = [LEGACY_COLOR_PALETTE[0]] * 2 + + if self.small_objects is not None: + small_objects = self.small_objects + labels += ["Small: R@50", "Small: R@75"] + values += [small_objects.recall_at_50, small_objects.recall_at_75] + colors += [LEGACY_COLOR_PALETTE[3]] * 2 + + if self.medium_objects is not None: + medium_objects = self.medium_objects + labels += ["Medium: R@50", "Medium: R@75"] + values += [medium_objects.recall_at_50, medium_objects.recall_at_75] + colors += [LEGACY_COLOR_PALETTE[2]] * 2 + + if self.large_objects is not None: + large_objects = self.large_objects + labels += ["Large: R@50", "Large: R@75"] + values += [large_objects.recall_at_50, large_objects.recall_at_75] + colors += [LEGACY_COLOR_PALETTE[4]] * 2 + + plt.rcParams["font.family"] = "monospace" + + _, ax = plt.subplots(figsize=(10, 6)) + ax.set_ylim(0, 1) + ax.set_ylabel("Value", fontweight="bold") + title = ( + f"Recall, by Object Size" + f"\n(target: {self.metric_target.value}," + f" averaging: {self.averaging_method.value})" + ) + ax.set_title(title, fontweight="bold") + + x_positions = range(len(labels)) + bars = ax.bar(x_positions, values, color=colors, align="center") + + ax.set_xticks(x_positions) + ax.set_xticklabels(labels, rotation=45, ha="right") + + for bar in bars: + y_value = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2, + y_value + 0.02, + f"{y_value:.2f}", + ha="center", + va="bottom", + ) + + plt.rcParams["font.family"] = "sans-serif" + + plt.tight_layout() + plt.show() diff --git a/supervision/tracker/byte_tracker/basetrack.py b/supervision/tracker/byte_tracker/basetrack.py index 806f75384..e2bc66d57 100644 --- a/supervision/tracker/byte_tracker/basetrack.py +++ b/supervision/tracker/byte_tracker/basetrack.py @@ -12,8 +12,6 @@ class TrackState(Enum): class BaseTrack: - _count = 0 - def __init__(self): self.track_id = 0 self.is_activated = False @@ -34,20 +32,13 @@ def __init__(self): def end_frame(self) -> int: return self.frame_id - @staticmethod - def next_id() -> int: - BaseTrack._count += 1 - return BaseTrack._count - - @staticmethod - def reset_counter(): - BaseTrack._count = 0 - BaseTrack.track_id = 0 - BaseTrack.start_frame = 0 - BaseTrack.frame_id = 0 - BaseTrack.time_since_update = 0 + def reset_counter(self): + self.track_id = 0 + self.start_frame = 0 + self.frame_id = 0 + self.time_since_update = 0 - def activate(self, *args): + def activate(self, *args, **kwargs): raise NotImplementedError def predict(self): diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index 89e1e2f2c..d1d567fab 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -9,11 +9,35 @@ from supervision.tracker.byte_tracker.kalman_filter import KalmanFilter +class IdCounter: + def __init__(self): + self.reset() + + def reset(self) -> None: + self._id = self.NO_ID + + def new_id(self) -> int: + self._id += 1 + return self._id + + @property + def NO_ID(self) -> int: + return 0 + + class STrack(BaseTrack): shared_kalman = KalmanFilter() - _external_count = 0 - def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames): + def __init__( + self, + tlwh, + score, + class_ids, + minimum_consecutive_frames, + internal_id_counter: IdCounter, + external_id_counter: IdCounter, + ): + super().__init__() # wait activate self._tlwh = np.asarray(tlwh, dtype=np.float32) self.kalman_filter = None @@ -24,10 +48,13 
@@ def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames): self.class_ids = class_ids self.tracklet_len = 0 - self.external_track_id = -1 - self.minimum_consecutive_frames = minimum_consecutive_frames + self.internal_id_counter = internal_id_counter + self.external_id_counter = external_id_counter + self.internal_track_id = self.internal_id_counter.NO_ID + self.external_track_id = self.external_id_counter.NO_ID + def predict(self): mean_state = self.mean.copy() if self.state != TrackState.Tracked: @@ -57,7 +84,7 @@ def multi_predict(stracks): def activate(self, kalman_filter, frame_id): """Start a new tracklet""" self.kalman_filter = kalman_filter - self.internal_track_id = self.next_id() + self.internal_track_id = self.internal_id_counter.new_id() self.mean, self.covariance = self.kalman_filter.initiate( self.tlwh_to_xyah(self._tlwh) ) @@ -68,12 +95,12 @@ def activate(self, kalman_filter, frame_id): self.is_activated = True if self.minimum_consecutive_frames == 1: - self.external_track_id = self.next_external_id() + self.external_track_id = self.external_id_counter.new_id() self.frame_id = frame_id self.start_frame = frame_id - def re_activate(self, new_track, frame_id, new_id=False): + def re_activate(self, new_track, frame_id): self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh) ) @@ -81,8 +108,6 @@ def re_activate(self, new_track, frame_id, new_id=False): self.state = TrackState.Tracked self.frame_id = frame_id - if new_id: - self.internal_track_id = self.next_id() self.score = new_track.score def update(self, new_track, frame_id): @@ -103,8 +128,8 @@ def update(self, new_track, frame_id): self.state = TrackState.Tracked if self.tracklet_len == self.minimum_consecutive_frames: self.is_activated = True - if self.external_track_id == -1: - self.external_track_id = self.next_external_id() + if self.external_track_id == self.external_id_counter.NO_ID: + self.external_track_id = self.external_id_counter.new_id() self.score = new_track.score @@ -142,15 +167,6 @@ def tlwh_to_xyah(tlwh): def to_xyah(self): return self.tlwh_to_xyah(self.tlwh) - @staticmethod - def next_external_id(): - STrack._external_count += 1 - return STrack._external_count - - @staticmethod - def reset_external_counter(): - STrack._external_count = 0 - @staticmethod def tlbr_to_tlwh(tlbr): ret = np.asarray(tlbr).copy() @@ -169,24 +185,6 @@ def __repr__(self): ) -def detections2boxes(detections: Detections) -> np.ndarray: - """ - Convert Supervision Detections to numpy tensors for further computation. - Args: - detections (Detections): Detections/Targets in the format of sv.Detections. - Returns: - (np.ndarray): Detections as numpy tensors as in - `(x_min, y_min, x_max, y_max, confidence, class_id)` order. - """ - return np.hstack( - ( - detections.xyxy, - detections.confidence[:, np.newaxis], - detections.class_id[:, np.newaxis], - ) - ) - - class ByteTrack: """ Initialize the ByteTrack object. 
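The new `IdCounter` gives each `ByteTrack` instance its own internal and external ID sequences in place of the old class-level counters. A rough sketch of the intended behavior (illustrative, standalone; not part of the patch):

```python
counter = IdCounter()
assert counter.NO_ID == 0     # sentinel meaning "no ID assigned yet"
assert counter.new_id() == 1  # IDs start at 1, never colliding with NO_ID
assert counter.new_id() == 2
counter.reset()               # per-instance reset; other trackers unaffected
assert counter.new_id() == 1
```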
@@ -235,6 +233,9 @@ def __init__(
         self.lost_tracks: List[STrack] = []
         self.removed_tracks: List[STrack] = []
 
+        self.internal_id_counter = IdCounter()
+        self.external_id_counter = IdCounter()
+
     def update_with_detections(self, detections: Detections) -> Detections:
         """
         Updates the tracker with the provided detections and returns the updated
@@ -275,7 +276,13 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
         ```
         """
 
-        tensors = detections2boxes(detections=detections)
+        tensors = np.hstack(
+            (
+                detections.xyxy,
+                detections.confidence[:, np.newaxis],
+                detections.class_id[:, np.newaxis],
+            )
+        )
         tracks = self.update_with_tensors(tensors=tensors)
 
         if len(tracks) > 0:
@@ -311,11 +318,11 @@ def reset(self):
         ensuring the tracker starts with a clean state for each new video.
         """
         self.frame_id = 0
+        self.internal_id_counter.reset()
+        self.external_id_counter.reset()
         self.tracked_tracks: List[STrack] = []
         self.lost_tracks: List[STrack] = []
         self.removed_tracks: List[STrack] = []
-        BaseTrack.reset_counter()
-        STrack.reset_external_counter()
 
     def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
         """
@@ -353,7 +361,14 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
         if len(dets) > 0:
             """Detections"""
             detections = [
-                STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
+                STrack(
+                    STrack.tlbr_to_tlwh(tlbr),
+                    s,
+                    c,
+                    self.minimum_consecutive_frames,
+                    self.internal_id_counter,
+                    self.external_id_counter,
+                )
                 for (tlbr, s, c) in zip(dets, scores_keep, class_ids_keep)
             ]
         else:
@@ -387,7 +402,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
                 track.update(detections[idet], self.frame_id)
                 activated_starcks.append(track)
             else:
-                track.re_activate(det, self.frame_id, new_id=False)
+                track.re_activate(det, self.frame_id)
                 refind_stracks.append(track)
 
         """ Step 3: Second association, with low score detection boxes"""
@@ -395,7 +410,14 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
         if len(dets_second) > 0:
             """Detections"""
             detections_second = [
-                STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
+                STrack(
+                    STrack.tlbr_to_tlwh(tlbr),
+                    s,
+                    c,
+                    self.minimum_consecutive_frames,
+                    self.internal_id_counter,
+                    self.external_id_counter,
+                )
                 for (tlbr, s, c) in zip(dets_second, scores_second, class_ids_second)
            ]
         else:
@@ -416,7 +438,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
                 track.update(det, self.frame_id)
                 activated_starcks.append(track)
             else:
-                track.re_activate(det, self.frame_id, new_id=False)
+                track.re_activate(det, self.frame_id)
                 refind_stracks.append(track)
 
         for it in u_track:
diff --git a/supervision/tracker/byte_tracker/matching.py b/supervision/tracker/byte_tracker/matching.py
index 24abe224c..f791e518f 100644
--- a/supervision/tracker/byte_tracker/matching.py
+++ b/supervision/tracker/byte_tracker/matching.py
@@ -20,7 +20,7 @@ def indices_to_matches(
 
 def linear_assignment(
     cost_matrix: np.ndarray, thresh: float
-) -> [np.ndarray, Tuple[int], Tuple[int, int]]:
+) -> Tuple[np.ndarray, Tuple[int], Tuple[int, int]]:
     if cost_matrix.size == 0:
         return (
             np.empty((0, 2), dtype=int),
diff --git a/test/tracker/__init__.py b/test/tracker/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/tracker/test_byte_tracker.py b/test/tracker/test_byte_tracker.py
new file mode 100644
index 000000000..98efeb093
--- /dev/null
+++ b/test/tracker/test_byte_tracker.py
@@ -0,0 +1,40 @@
+from typing import List
+
+import numpy as np +import pytest + +import supervision as sv + + +@pytest.mark.parametrize( + "detections, expected_results", + [ + ( + [ + sv.Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + class_id=np.array([1, 1]), + confidence=np.array([1, 1]), + ), + sv.Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + class_id=np.array([1, 1]), + confidence=np.array([1, 1]), + ), + ], + sv.Detections( + xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]), + class_id=np.array([1, 1]), + confidence=np.array([1, 1]), + tracker_id=np.array([1, 2]), + ), + ), + ], +) +def test_byte_tracker( + detections: List[sv.Detections], + expected_results: sv.Detections, +) -> None: + byte_tracker = sv.ByteTrack() + tracked_detections = [byte_tracker.update_with_detections(d) for d in detections] + assert tracked_detections[-1] == expected_results
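+
+
+# Illustrative sketch: with per-instance IdCounters, `reset()` should make a
+# tracker reproduce the same tracker IDs when fed the same detections again.
+def make_detections() -> sv.Detections:
+    return sv.Detections(
+        xyxy=np.array([[10, 10, 20, 20], [30, 30, 40, 40]]),
+        class_id=np.array([1, 1]),
+        confidence=np.array([1, 1]),
+    )
+
+
+def test_byte_tracker_reset_restarts_ids() -> None:
+    tracker = sv.ByteTrack()
+    first = [tracker.update_with_detections(make_detections()) for _ in range(2)]
+    tracker.reset()
+    second = [tracker.update_with_detections(make_detections()) for _ in range(2)]
+    assert first[-1] == second[-1]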