diff --git a/poetry.lock b/poetry.lock index 6e2817a5d..d7e29a430 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4037,29 +4037,29 @@ files = [ [[package]] name = "ruff" -version = "0.7.1" +version = "0.7.2" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.7.1-py3-none-linux_armv6l.whl", hash = "sha256:cb1bc5ed9403daa7da05475d615739cc0212e861b7306f314379d958592aaa89"}, - {file = "ruff-0.7.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:27c1c52a8d199a257ff1e5582d078eab7145129aa02721815ca8fa4f9612dc35"}, - {file = "ruff-0.7.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:588a34e1ef2ea55b4ddfec26bbe76bc866e92523d8c6cdec5e8aceefeff02d99"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94fc32f9cdf72dc75c451e5f072758b118ab8100727168a3df58502b43a599ca"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:985818742b833bffa543a84d1cc11b5e6871de1b4e0ac3060a59a2bae3969250"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32f1e8a192e261366c702c5fb2ece9f68d26625f198a25c408861c16dc2dea9c"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:699085bf05819588551b11751eff33e9ca58b1b86a6843e1b082a7de40da1565"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:344cc2b0814047dc8c3a8ff2cd1f3d808bb23c6658db830d25147339d9bf9ea7"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4316bbf69d5a859cc937890c7ac7a6551252b6a01b1d2c97e8fc96e45a7c8b4a"}, - {file = "ruff-0.7.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79d3af9dca4c56043e738a4d6dd1e9444b6d6c10598ac52d146e331eb155a8ad"}, - {file = "ruff-0.7.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c5c121b46abde94a505175524e51891f829414e093cd8326d6e741ecfc0a9112"}, - {file = "ruff-0.7.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8422104078324ea250886954e48f1373a8fe7de59283d747c3a7eca050b4e378"}, - {file = "ruff-0.7.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:56aad830af8a9db644e80098fe4984a948e2b6fc2e73891538f43bbe478461b8"}, - {file = "ruff-0.7.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:658304f02f68d3a83c998ad8bf91f9b4f53e93e5412b8f2388359d55869727fd"}, - {file = "ruff-0.7.1-py3-none-win32.whl", hash = "sha256:b517a2011333eb7ce2d402652ecaa0ac1a30c114fbbd55c6b8ee466a7f600ee9"}, - {file = "ruff-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f38c41fcde1728736b4eb2b18850f6d1e3eedd9678c914dede554a70d5241307"}, - {file = "ruff-0.7.1-py3-none-win_arm64.whl", hash = "sha256:19aa200ec824c0f36d0c9114c8ec0087082021732979a359d6f3c390a6ff2a37"}, - {file = "ruff-0.7.1.tar.gz", hash = "sha256:9d8a41d4aa2dad1575adb98a82870cf5db5f76b2938cf2206c22c940034a36f4"}, + {file = "ruff-0.7.2-py3-none-linux_armv6l.whl", hash = "sha256:b73f873b5f52092e63ed540adefc3c36f1f803790ecf2590e1df8bf0a9f72cb8"}, + {file = "ruff-0.7.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5b813ef26db1015953daf476202585512afd6a6862a02cde63f3bafb53d0b2d4"}, + {file = "ruff-0.7.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:853277dbd9675810c6826dad7a428d52a11760744508340e66bf46f8be9701d9"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21aae53ab1490a52bf4e3bf520c10ce120987b047c494cacf4edad0ba0888da2"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ccc7e0fc6e0cb3168443eeadb6445285abaae75142ee22b2b72c27d790ab60ba"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd77877a4e43b3a98e5ef4715ba3862105e299af0c48942cc6d51ba3d97dc859"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e00163fb897d35523c70d71a46fbaa43bf7bf9af0f4534c53ea5b96b2e03397b"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3c54b538633482dc342e9b634d91168fe8cc56b30a4b4f99287f4e339103e88"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b792468e9804a204be221b14257566669d1db5c00d6bb335996e5cd7004ba80"}, + {file = "ruff-0.7.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dba53ed84ac19ae4bfb4ea4bf0172550a2285fa27fbb13e3746f04c80f7fa088"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b19fafe261bf741bca2764c14cbb4ee1819b67adb63ebc2db6401dcd652e3748"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:28bd8220f4d8f79d590db9e2f6a0674f75ddbc3847277dd44ac1f8d30684b828"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9fd67094e77efbea932e62b5d2483006154794040abb3a5072e659096415ae1e"}, + {file = "ruff-0.7.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:576305393998b7bd6c46018f8104ea3a9cb3fa7908c21d8580e3274a3b04b691"}, + {file = "ruff-0.7.2-py3-none-win32.whl", hash = "sha256:fa993cfc9f0ff11187e82de874dfc3611df80852540331bc85c75809c93253a8"}, + {file = "ruff-0.7.2-py3-none-win_amd64.whl", hash = "sha256:dd8800cbe0254e06b8fec585e97554047fb82c894973f7ff18558eee33d1cb88"}, + {file = "ruff-0.7.2-py3-none-win_arm64.whl", hash = "sha256:bb8368cd45bba3f57bb29cbb8d64b4a33f8415d0149d2655c5c8539452ce7760"}, + {file = "ruff-0.7.2.tar.gz", hash = "sha256:2b14e77293380e475b4e3a7a368e14549288ed2931fce259a6f99978669e844f"}, ] [[package]] diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 113948fc9..32753a30a 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -32,8 +32,10 @@ extract_ultralytics_masks, get_data_item, is_data_equal, + is_metadata_equal, mask_to_xyxy, merge_data, + merge_metadata, process_roboflow_result, xywh_to_xyxy, ) @@ -125,6 +127,9 @@ class simplifies data manipulation and filtering, providing a uniform API for data (Dict[str, Union[np.ndarray, List]]): A dictionary containing additional data where each key is a string representing the data type, and the value is either a NumPy array or a list of corresponding data. + metadata (Dict[str, Any]): A dictionary containing collection-level metadata + that applies to the entire set of detections. This may include information such + as the video name, camera parameters, timestamp, or other global metadata. """ # noqa: E501 // docs xyxy: np.ndarray @@ -133,6 +138,7 @@ class simplifies data manipulation and filtering, providing a uniform API for class_id: Optional[np.ndarray] = None tracker_id: Optional[np.ndarray] = None data: Dict[str, Union[np.ndarray, List]] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): validate_detections_fields( @@ -185,6 +191,7 @@ def __eq__(self, other: Detections): np.array_equal(self.confidence, other.confidence), np.array_equal(self.tracker_id, other.tracker_id), is_data_equal(self.data, other.data), + is_metadata_equal(self.metadata, other.metadata), ] ) @@ -985,6 +992,7 @@ def is_empty(self) -> bool: """ empty_detections = Detections.empty() empty_detections.data = self.data + empty_detections.metadata = self.metadata return self == empty_detections @classmethod @@ -1078,6 +1086,9 @@ def stack_or_none(name: str): data = merge_data([d.data for d in detections_list]) + metadata_list = [detections.metadata for detections in detections_list] + metadata = merge_metadata(metadata_list) + return cls( xyxy=xyxy, mask=mask, @@ -1085,6 +1096,7 @@ def stack_or_none(name: str): class_id=class_id, tracker_id=tracker_id, data=data, + metadata=metadata, ) def get_anchors_coordinates(self, anchor: Position) -> np.ndarray: @@ -1198,6 +1210,7 @@ def __getitem__( class_id=self.class_id[index] if self.class_id is not None else None, tracker_id=self.tracker_id[index] if self.tracker_id is not None else None, data=get_data_item(self.data, index), + metadata=self.metadata, ) def __setitem__(self, key: str, value: Union[np.ndarray, List]): @@ -1459,6 +1472,8 @@ def merge_inner_detection_object_pair( else: winning_detection = detections_2 + metadata = merge_metadata([detections_1.metadata, detections_2.metadata]) + return Detections( xyxy=merged_xyxy, mask=merged_mask, @@ -1466,6 +1481,7 @@ def merge_inner_detection_object_pair( class_id=winning_detection.class_id, tracker_id=winning_detection.tracker_id, data=winning_detection.data, + metadata=metadata, ) diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 69cdacc7e..fc4458fa2 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -1,5 +1,5 @@ from itertools import chain -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import cv2 import numpy as np @@ -807,12 +807,36 @@ def is_data_equal(data_a: Dict[str, np.ndarray], data_b: Dict[str, np.ndarray]) ) +def is_metadata_equal(metadata_a: Dict[str, Any], metadata_b: Dict[str, Any]) -> bool: + """ + Compares the metadata payloads of two Detections instances. + + Args: + metadata_a, metadata_b: The metadata payloads of the instances. + + Returns: + True if the metadata payloads are equal, False otherwise. + """ + return set(metadata_a.keys()) == set(metadata_b.keys()) and all( + np.array_equal(metadata_a[key], metadata_b[key]) + if ( + isinstance(metadata_a[key], np.ndarray) + and isinstance(metadata_b[key], np.ndarray) + ) + else metadata_a[key] == metadata_b[key] + for key in metadata_a + ) + + def merge_data( data_list: List[Dict[str, Union[npt.NDArray[np.generic], List]]], ) -> Dict[str, Union[npt.NDArray[np.generic], List]]: """ Merges the data payloads of a list of Detections instances. + Warning: Assumes that empty detections were filtered-out before passing data to + this function. + Args: data_list: The data payloads of the Detections instances. Each data payload is a dictionary with the same keys, and the values are either lists or @@ -865,6 +889,45 @@ def merge_data( return merged_data +def merge_metadata(metadata_list: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Merge metadata from a list of metadata dictionaries. + + This function combines the metadata dictionaries. If a key appears in more than one + dictionary, the values must be identical for the merge to succeed. + + Warning: Assumes that empty detections were filtered-out before passing metadata to + this function. + + Args: + metadata_list (List[Dict[str, Any]]): A list of metadata dictionaries to merge. + + Returns: + Dict[str, Any]: A single merged metadata dictionary. + + Raises: + ValueError: If there are conflicting values for the same key or if + dictionaries have different keys. + """ + if not metadata_list: + return {} + + all_keys_sets = [set(metadata.keys()) for metadata in metadata_list] + if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets): + raise ValueError("All metadata dictionaries must have the same keys to merge.") + + merged_metadata: Dict[str, Any] = {} + for metadata in metadata_list: + for key, value in metadata.items(): + if key in merged_metadata: + if merged_metadata[key] != value: + raise ValueError(f"Conflicting metadata for key: '{key}'.") + else: + merged_metadata[key] = value + + return merged_metadata + + def get_data_item( data: Dict[str, Union[np.ndarray, List]], index: Union[int, slice, List[int], np.ndarray], diff --git a/test/utils/test_internal.py b/test/utils/test_internal.py index eee614e6c..872822a7c 100644 --- a/test/utils/test_internal.py +++ b/test/utils/test_internal.py @@ -121,7 +121,15 @@ def __private_property(self): ( Detections.empty(), False, - {"xyxy", "class_id", "confidence", "mask", "tracker_id", "data"}, + { + "xyxy", + "class_id", + "confidence", + "mask", + "tracker_id", + "data", + "metadata", + }, DoesNotRaise(), ), ( @@ -134,6 +142,7 @@ def __private_property(self): "mask", "tracker_id", "data", + "metadata", "area", "box_area", }, @@ -149,6 +158,7 @@ def __private_property(self): "mask", "tracker_id", "data", + "metadata", }, DoesNotRaise(), ), @@ -169,13 +179,22 @@ def __private_property(self): "mask", "tracker_id", "data", + "metadata", }, DoesNotRaise(), ), ( Detections.empty(), False, - {"xyxy", "class_id", "confidence", "mask", "tracker_id", "data"}, + { + "xyxy", + "class_id", + "confidence", + "mask", + "tracker_id", + "data", + "metadata", + }, DoesNotRaise(), ), ],