From c17de5c11b220b07ea8343ee3ebbedfa52130b22 Mon Sep 17 00:00:00 2001 From: LinasKo Date: Fri, 18 Oct 2024 14:37:36 +0300 Subject: [PATCH 1/3] Add precision and recall metrics --- docs/metrics/precision.md | 18 + docs/metrics/recall.md | 18 + mkdocs.yml | 2 + supervision/metrics/precision.py | 547 +++++++++++++++++++++++++++++++ supervision/metrics/recall.py | 545 ++++++++++++++++++++++++++++++ 5 files changed, 1130 insertions(+) create mode 100644 docs/metrics/precision.md create mode 100644 docs/metrics/recall.md create mode 100644 supervision/metrics/precision.py create mode 100644 supervision/metrics/recall.py diff --git a/docs/metrics/precision.md b/docs/metrics/precision.md new file mode 100644 index 000000000..ca318f8fb --- /dev/null +++ b/docs/metrics/precision.md @@ -0,0 +1,18 @@ +--- +comments: true +status: new +--- + +# F1 Score + +
+<div class="md-typeset">
+    <h2>Precision</h2>
+</div>
+
+:::supervision.metrics.precision.Precision
+
+<div class="md-typeset">
+    <h2>PrecisionResult</h2>
+</div>
+ +:::supervision.metrics.precision.PrecisionResult diff --git a/docs/metrics/recall.md b/docs/metrics/recall.md new file mode 100644 index 000000000..5baa4d3ee --- /dev/null +++ b/docs/metrics/recall.md @@ -0,0 +1,18 @@ +--- +comments: true +status: new +--- + +# F1 Score + +
+<div class="md-typeset">
+    <h2>Recall</h2>
+</div>
+
+:::supervision.metrics.recall.Recall
+
+<div class="md-typeset">
+    <h2>RecallResult</h2>
+</div>
+ +:::supervision.metrics.recall.RecallResult diff --git a/mkdocs.yml b/mkdocs.yml index 3cd867590..a3c9c1caa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -66,6 +66,8 @@ nav: - Utils: datasets/utils.md - Metrics: - mAP: metrics/mean_average_precision.md + - Precision: metrics/precision.md + - Recall: metrics/recall.md - F1 Score: metrics/f1_score.md - Legacy Metrics: detection/metrics.md - Utils: diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py new file mode 100644 index 000000000..ba441831a --- /dev/null +++ b/supervision/metrics/precision.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import numpy as np +from matplotlib import pyplot as plt + +from supervision.config import ORIENTED_BOX_COORDINATES +from supervision.detection.core import Detections +from supervision.detection.utils import box_iou_batch, mask_iou_batch +from supervision.draw.color import LEGACY_COLOR_PALETTE +from supervision.metrics.core import AveragingMethod, Metric, MetricTarget +from supervision.metrics.utils.object_size import ( + ObjectSizeCategory, + get_detection_size_category, +) +from supervision.metrics.utils.utils import ensure_pandas_installed + +if TYPE_CHECKING: + import pandas as pd + + +class Precision(Metric): + def __init__( + self, + metric_target: MetricTarget = MetricTarget.BOXES, + averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, + ): + self._metric_target = metric_target + if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + raise NotImplementedError( + "Precision is not implemented for oriented bounding boxes." + ) + + self._metric_target = metric_target + self.averaging_method = averaging_method + self._predictions_list: List[Detections] = [] + self._targets_list: List[Detections] = [] + + def reset(self) -> None: + self._predictions_list = [] + self._targets_list = [] + + def update( + self, + predictions: Union[Detections, List[Detections]], + targets: Union[Detections, List[Detections]], + ) -> Precision: + if not isinstance(predictions, list): + predictions = [predictions] + if not isinstance(targets, list): + targets = [targets] + + if len(predictions) != len(targets): + raise ValueError( + f"The number of predictions ({len(predictions)}) and" + f" targets ({len(targets)}) during the update must be the same." 
+ ) + + self._predictions_list.extend(predictions) + self._targets_list.extend(targets) + + return self + + def compute(self) -> PrecisionResult: + result = self._compute(self._predictions_list, self._targets_list) + + small_predictions, small_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL + ) + result.small_objects = self._compute(small_predictions, small_targets) + + medium_predictions, medium_targets = ( + self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM + ) + ) + result.medium_objects = self._compute(medium_predictions, medium_targets) + + large_predictions, large_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE + ) + result.large_objects = self._compute(large_predictions, large_targets) + + return result + + def _compute( + self, predictions_list: List[Detections], targets_list: List[Detections] + ) -> PrecisionResult: + iou_thresholds = np.linspace(0.5, 0.95, 10) + stats = [] + + for predictions, targets in zip(predictions_list, targets_list): + prediction_contents = self._detections_content(predictions) + target_contents = self._detections_content(targets) + + if len(targets) > 0: + if len(predictions) == 0: + stats.append( + ( + np.zeros((0, iou_thresholds.size), dtype=bool), + np.zeros((0,), dtype=np.float32), + np.zeros((0,), dtype=int), + targets.class_id, + ) + ) + + else: + if self._metric_target == MetricTarget.BOXES: + iou = box_iou_batch(target_contents, prediction_contents) + elif self._metric_target == MetricTarget.MASKS: + iou = mask_iou_batch(target_contents, prediction_contents) + else: + raise NotImplementedError( + "Unsupported metric target for IoU calculation" + ) + + matches = self._match_detection_batch( + predictions.class_id, targets.class_id, iou, iou_thresholds + ) + stats.append( + ( + matches, + predictions.confidence, + predictions.class_id, + targets.class_id, + ) + ) + + if not stats: + return PrecisionResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + precision_scores=np.zeros(iou_thresholds.shape[0]), + precision_per_class=np.zeros((0, iou_thresholds.shape[0])), + iou_thresholds=iou_thresholds, + matched_classes=np.array([], dtype=int), + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)] + precision_scores, precision_per_class, unique_classes = ( + self._compute_precision_for_classes(*concatenated_stats) + ) + + return PrecisionResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + precision_scores=precision_scores, + precision_per_class=precision_per_class, + iou_thresholds=iou_thresholds, + matched_classes=unique_classes, + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + def _compute_precision_for_classes( + self, + matches: np.ndarray, + prediction_confidence: np.ndarray, + prediction_class_ids: np.ndarray, + true_class_ids: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sorted_indices = np.argsort(-prediction_confidence) + matches = matches[sorted_indices] + prediction_class_ids = prediction_class_ids[sorted_indices] + unique_classes, class_counts = np.unique(true_class_ids, return_counts=True) + + # Shape: PxTh,P,C,C -> CxThx3 + confusion_matrix = self._compute_confusion_matrix( + matches, prediction_class_ids, 
unique_classes, class_counts + ) + + # Shape: CxThx3 -> CxTh + precision_per_class = self._compute_precision(confusion_matrix) + + # Shape: CxTh -> Th + if self.averaging_method == AveragingMethod.MACRO: + precision_scores = np.mean(precision_per_class, axis=0) + elif self.averaging_method == AveragingMethod.MICRO: + confusion_matrix_merged = confusion_matrix.sum(0) + precision_scores = self._compute_precision(confusion_matrix_merged) + elif self.averaging_method == AveragingMethod.WEIGHTED: + class_counts = class_counts.astype(np.float32) + precision_scores = np.average( + precision_per_class, axis=0, weights=class_counts + ) + + return precision_scores, precision_per_class, unique_classes + + @staticmethod + def _match_detection_batch( + predictions_classes: np.ndarray, + target_classes: np.ndarray, + iou: np.ndarray, + iou_thresholds: np.ndarray, + ) -> np.ndarray: + num_predictions, num_iou_levels = ( + predictions_classes.shape[0], + iou_thresholds.shape[0], + ) + correct = np.zeros((num_predictions, num_iou_levels), dtype=bool) + correct_class = target_classes[:, None] == predictions_classes + + for i, iou_level in enumerate(iou_thresholds): + matched_indices = np.where((iou >= iou_level) & correct_class) + + if matched_indices[0].shape[0]: + combined_indices = np.stack(matched_indices, axis=1) + iou_values = iou[matched_indices][:, None] + matches = np.hstack([combined_indices, iou_values]) + + if matched_indices[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + + correct[matches[:, 1].astype(int), i] = True + + return correct + + @staticmethod + def _compute_confusion_matrix( + sorted_matches: np.ndarray, + sorted_prediction_class_ids: np.ndarray, + unique_classes: np.ndarray, + class_counts: np.ndarray, + ) -> np.ndarray: + """ + Compute the confusion matrix for each class and IoU threshold. + + Assumes the matches and prediction_class_ids are sorted by confidence + in descending order. + + Arguments: + sorted_matches: np.ndarray, bool, shape (P, Th), that is True + if the prediction is a true positive at the given IoU threshold. + sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing + the class id for each prediction. + unique_classes: np.ndarray, int, shape (C,), containing the unique + class ids. + class_counts: np.ndarray, int, shape (C,), containing the number + of true instances for each class. + + Returns: + np.ndarray, shape (C, Th, 3), containing the true positives, false + positives, and false negatives for each class and IoU threshold. 
+ """ + + num_thresholds = sorted_matches.shape[1] + num_classes = unique_classes.shape[0] + + confusion_matrix = np.zeros((num_classes, num_thresholds, 3)) + for class_idx, class_id in enumerate(unique_classes): + is_class = sorted_prediction_class_ids == class_id + num_true = class_counts[class_idx] + num_predictions = is_class.sum() + + if num_predictions == 0: + true_positives = np.zeros(num_thresholds) + false_positives = np.zeros(num_thresholds) + false_negatives = np.full(num_thresholds, num_true) + elif num_true == 0: + true_positives = np.zeros(num_thresholds) + false_positives = np.full(num_thresholds, num_predictions) + false_negatives = np.zeros(num_thresholds) + else: + true_positives = sorted_matches[is_class].sum(0) + false_positives = (1 - sorted_matches[is_class]).sum(0) + false_negatives = num_true - true_positives + confusion_matrix[class_idx] = np.stack( + [true_positives, false_positives, false_negatives], axis=1 + ) + + return confusion_matrix + + @staticmethod + def _compute_precision(confusion_matrix: np.ndarray) -> np.ndarray: + """ + Broadcastable function, computing the precision from the confusion matrix. + + Arguments: + confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension + contains the true positives, false positives, and false negatives. + + Returns: + np.ndarray, shape (N, ...), containing the precision for each element. + """ + if not confusion_matrix.shape[-1] == 3: + raise ValueError( + f"Confusion matrix must have shape (..., 3), got " + f"{confusion_matrix.shape}" + ) + true_positives = confusion_matrix[..., 0] + false_positives = confusion_matrix[..., 1] + + denominator = true_positives + false_positives + precision = np.where(denominator == 0, 0, true_positives / denominator) + + return precision + + def _detections_content(self, detections: Detections) -> np.ndarray: + """Return boxes, masks or oriented bounding boxes from detections.""" + if self._metric_target == MetricTarget.BOXES: + return detections.xyxy + if self._metric_target == MetricTarget.MASKS: + return ( + detections.mask + if detections.mask is not None + else np.empty((0, 0, 0), dtype=bool) + ) + if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + if obb := detections.data.get(ORIENTED_BOX_COORDINATES): + return np.ndarray(obb, dtype=np.float32) + return np.empty((0, 8), dtype=np.float32) + raise ValueError(f"Invalid metric target: {self._metric_target}") + + def _filter_detections_by_size( + self, detections: Detections, size_category: ObjectSizeCategory + ) -> Detections: + """Return a copy of detections with contents filtered by object size.""" + new_detections = deepcopy(detections) + if detections.is_empty() or size_category == ObjectSizeCategory.ANY: + return new_detections + + sizes = get_detection_size_category(new_detections, self._metric_target) + size_mask = sizes == size_category.value + + new_detections.xyxy = new_detections.xyxy[size_mask] + if new_detections.mask is not None: + new_detections.mask = new_detections.mask[size_mask] + if new_detections.class_id is not None: + new_detections.class_id = new_detections.class_id[size_mask] + if new_detections.confidence is not None: + new_detections.confidence = new_detections.confidence[size_mask] + if new_detections.tracker_id is not None: + new_detections.tracker_id = new_detections.tracker_id[size_mask] + if new_detections.data is not None: + for key, value in new_detections.data.items(): + new_detections.data[key] = np.array(value)[size_mask] + + return new_detections + + def 
_filter_predictions_and_targets_by_size( + self, + predictions_list: List[Detections], + targets_list: List[Detections], + size_category: ObjectSizeCategory, + ) -> Tuple[List[Detections], List[Detections]]: + """ + Filter predictions and targets by object size category. + """ + new_predictions_list = [] + new_targets_list = [] + for predictions, targets in zip(predictions_list, targets_list): + new_predictions_list.append( + self._filter_detections_by_size(predictions, size_category) + ) + new_targets_list.append( + self._filter_detections_by_size(targets, size_category) + ) + return new_predictions_list, new_targets_list + + +@dataclass +class PrecisionResult: + """ + The results of the precision metric calculation. + + Defaults to `0` if no detections or targets were provided. + Provides a custom `__str__` method for pretty printing. + + Attributes: + metric_target (MetricTarget): the type of data used for the metric - + boxes, masks or oriented bounding boxes. + averaging_method (AveragingMethod): the averaging method used to compute the + precision. Determines how the precision is aggregated across classes. + precision_at_50 (float): the precision at IoU threshold of `0.5`. + precision_at_75 (float): the precision at IoU threshold of `0.75`. + precision_scores (np.ndarray): the precision scores at each IoU threshold. + Shape: `(num_iou_thresholds,)` + precision_per_class (np.ndarray): the precision scores per class and + IoU threshold. Shape: `(num_target_classes, num_iou_thresholds)` + iou_thresholds (np.ndarray): the IoU thresholds used in the calculations. + matched_classes (np.ndarray): the class IDs of all matched classes. + Corresponds to the rows of `precision_per_class`. + small_objects (Optional[PrecisionResult]): the Precision metric results + for small objects. + medium_objects (Optional[PrecisionResult]): the Precision metric results + for medium objects. + large_objects (Optional[PrecisionResult]): the Precision metric results + for large objects. + """ + + metric_target: MetricTarget + averaging_method: AveragingMethod + + @property + def precision_at_50(self) -> float: + return self.precision_scores[0] + + @property + def precision_at_75(self) -> float: + return self.precision_scores[5] + + precision_scores: np.ndarray + precision_per_class: np.ndarray + iou_thresholds: np.ndarray + matched_classes: np.ndarray + + small_objects: Optional[PrecisionResult] + medium_objects: Optional[PrecisionResult] + large_objects: Optional[PrecisionResult] + + def __str__(self) -> str: + """ + Format as a pretty string. 
+ + Example: + ```python + print(precision_result) + ``` + """ + out_str = ( + f"{self.__class__.__name__}:\n" + f"Metric target: {self.metric_target}\n" + f"Averaging method: {self.averaging_method}\n" + f"P @ 50: {self.precision_at_50:.4f}\n" + f"P @ 75: {self.precision_at_75:.4f}\n" + f"P @ thresh: {self.precision_scores}\n" + f"IoU thresh: {self.iou_thresholds}\n" + f"Precision per class:\n" + ) + if self.precision_per_class.size == 0: + out_str += " No results\n" + for class_id, precision_of_class in zip( + self.matched_classes, self.precision_per_class + ): + out_str += f" {class_id}: {precision_of_class}\n" + + indent = " " + if self.small_objects is not None: + indented = indent + str(self.small_objects).replace("\n", f"\n{indent}") + out_str += f"\nSmall objects:\n{indented}" + if self.medium_objects is not None: + indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}") + out_str += f"\nMedium objects:\n{indented}" + if self.large_objects is not None: + indented = indent + str(self.large_objects).replace("\n", f"\n{indent}") + out_str += f"\nLarge objects:\n{indented}" + + return out_str + + def to_pandas(self) -> "pd.DataFrame": + """ + Convert the result to a pandas DataFrame. + + Returns: + (pd.DataFrame): The result as a DataFrame. + """ + ensure_pandas_installed() + import pandas as pd + + pandas_data = { + "P@50": self.precision_at_50, + "P@75": self.precision_at_75, + } + + if self.small_objects is not None: + small_objects_df = self.small_objects.to_pandas() + for key, value in small_objects_df.items(): + pandas_data[f"small_objects_{key}"] = value + if self.medium_objects is not None: + medium_objects_df = self.medium_objects.to_pandas() + for key, value in medium_objects_df.items(): + pandas_data[f"medium_objects_{key}"] = value + if self.large_objects is not None: + large_objects_df = self.large_objects.to_pandas() + for key, value in large_objects_df.items(): + pandas_data[f"large_objects_{key}"] = value + + return pd.DataFrame(pandas_data, index=[0]) + + def plot(self): + """ + Plot the precision results. 
+ """ + + labels = ["Precision@50", "Precision@75"] + values = [self.precision_at_50, self.precision_at_75] + colors = [LEGACY_COLOR_PALETTE[0]] * 2 + + if self.small_objects is not None: + small_objects = self.small_objects + labels += ["Small: P@50", "Small: P@75"] + values += [small_objects.precision_at_50, small_objects.precision_at_75] + colors += [LEGACY_COLOR_PALETTE[3]] * 2 + + if self.medium_objects is not None: + medium_objects = self.medium_objects + labels += ["Medium: P@50", "Medium: P@75"] + values += [medium_objects.precision_at_50, medium_objects.precision_at_75] + colors += [LEGACY_COLOR_PALETTE[2]] * 2 + + if self.large_objects is not None: + large_objects = self.large_objects + labels += ["Large: P@50", "Large: P@75"] + values += [large_objects.precision_at_50, large_objects.precision_at_75] + colors += [LEGACY_COLOR_PALETTE[4]] * 2 + + plt.rcParams["font.family"] = "monospace" + + _, ax = plt.subplots(figsize=(10, 6)) + ax.set_ylim(0, 1) + ax.set_ylabel("Value", fontweight="bold") + title = ( + f"Precision, by Object Size" + f"\n(target: {self.metric_target.value}," + f" averaging: {self.averaging_method.value})" + ) + ax.set_title(title, fontweight="bold") + + x_positions = range(len(labels)) + bars = ax.bar(x_positions, values, color=colors, align="center") + + ax.set_xticks(x_positions) + ax.set_xticklabels(labels, rotation=45, ha="right") + + for bar in bars: + y_value = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2, + y_value + 0.02, + f"{y_value:.2f}", + ha="center", + va="bottom", + ) + + plt.rcParams["font.family"] = "sans-serif" + + plt.tight_layout() + plt.show() diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py new file mode 100644 index 000000000..7c90859cc --- /dev/null +++ b/supervision/metrics/recall.py @@ -0,0 +1,545 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import numpy as np +from matplotlib import pyplot as plt + +from supervision.config import ORIENTED_BOX_COORDINATES +from supervision.detection.core import Detections +from supervision.detection.utils import box_iou_batch, mask_iou_batch +from supervision.draw.color import LEGACY_COLOR_PALETTE +from supervision.metrics.core import AveragingMethod, Metric, MetricTarget +from supervision.metrics.utils.object_size import ( + ObjectSizeCategory, + get_detection_size_category, +) +from supervision.metrics.utils.utils import ensure_pandas_installed + +if TYPE_CHECKING: + import pandas as pd + + +class Recall(Metric): + def __init__( + self, + metric_target: MetricTarget = MetricTarget.BOXES, + averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, + ): + self._metric_target = metric_target + if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + raise NotImplementedError( + "Recall is not implemented for oriented bounding boxes." 
+ ) + + self._metric_target = metric_target + self.averaging_method = averaging_method + self._predictions_list: List[Detections] = [] + self._targets_list: List[Detections] = [] + + def reset(self) -> None: + self._predictions_list = [] + self._targets_list = [] + + def update( + self, + predictions: Union[Detections, List[Detections]], + targets: Union[Detections, List[Detections]], + ) -> Recall: + if not isinstance(predictions, list): + predictions = [predictions] + if not isinstance(targets, list): + targets = [targets] + + if len(predictions) != len(targets): + raise ValueError( + f"The number of predictions ({len(predictions)}) and" + f" targets ({len(targets)}) during the update must be the same." + ) + + self._predictions_list.extend(predictions) + self._targets_list.extend(targets) + + return self + + def compute(self) -> RecallResult: + result = self._compute(self._predictions_list, self._targets_list) + + small_predictions, small_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.SMALL + ) + result.small_objects = self._compute(small_predictions, small_targets) + + medium_predictions, medium_targets = ( + self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.MEDIUM + ) + ) + result.medium_objects = self._compute(medium_predictions, medium_targets) + + large_predictions, large_targets = self._filter_predictions_and_targets_by_size( + self._predictions_list, self._targets_list, ObjectSizeCategory.LARGE + ) + result.large_objects = self._compute(large_predictions, large_targets) + + return result + + def _compute( + self, predictions_list: List[Detections], targets_list: List[Detections] + ) -> RecallResult: + iou_thresholds = np.linspace(0.5, 0.95, 10) + stats = [] + + for predictions, targets in zip(predictions_list, targets_list): + prediction_contents = self._detections_content(predictions) + target_contents = self._detections_content(targets) + + if len(targets) > 0: + if len(predictions) == 0: + stats.append( + ( + np.zeros((0, iou_thresholds.size), dtype=bool), + np.zeros((0,), dtype=np.float32), + np.zeros((0,), dtype=int), + targets.class_id, + ) + ) + + else: + if self._metric_target == MetricTarget.BOXES: + iou = box_iou_batch(target_contents, prediction_contents) + elif self._metric_target == MetricTarget.MASKS: + iou = mask_iou_batch(target_contents, prediction_contents) + else: + raise NotImplementedError( + "Unsupported metric target for IoU calculation" + ) + + matches = self._match_detection_batch( + predictions.class_id, targets.class_id, iou, iou_thresholds + ) + stats.append( + ( + matches, + predictions.confidence, + predictions.class_id, + targets.class_id, + ) + ) + + if not stats: + return RecallResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + recall_scores=np.zeros(iou_thresholds.shape[0]), + recall_per_class=np.zeros((0, iou_thresholds.shape[0])), + iou_thresholds=iou_thresholds, + matched_classes=np.array([], dtype=int), + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)] + recall_scores, recall_per_class, unique_classes = ( + self._compute_recall_for_classes(*concatenated_stats) + ) + + return RecallResult( + metric_target=self._metric_target, + averaging_method=self.averaging_method, + recall_scores=recall_scores, + recall_per_class=recall_per_class, + iou_thresholds=iou_thresholds, + 
matched_classes=unique_classes, + small_objects=None, + medium_objects=None, + large_objects=None, + ) + + def _compute_recall_for_classes( + self, + matches: np.ndarray, + prediction_confidence: np.ndarray, + prediction_class_ids: np.ndarray, + true_class_ids: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + sorted_indices = np.argsort(-prediction_confidence) + matches = matches[sorted_indices] + prediction_class_ids = prediction_class_ids[sorted_indices] + unique_classes, class_counts = np.unique(true_class_ids, return_counts=True) + + # Shape: PxTh,P,C,C -> CxThx3 + confusion_matrix = self._compute_confusion_matrix( + matches, prediction_class_ids, unique_classes, class_counts + ) + + # Shape: CxThx3 -> CxTh + recall_per_class = self._compute_recall(confusion_matrix) + + # Shape: CxTh -> Th + if self.averaging_method == AveragingMethod.MACRO: + recall_scores = np.mean(recall_per_class, axis=0) + elif self.averaging_method == AveragingMethod.MICRO: + confusion_matrix_merged = confusion_matrix.sum(0) + recall_scores = self._compute_recall(confusion_matrix_merged) + elif self.averaging_method == AveragingMethod.WEIGHTED: + class_counts = class_counts.astype(np.float32) + recall_scores = np.average(recall_per_class, axis=0, weights=class_counts) + + return recall_scores, recall_per_class, unique_classes + + @staticmethod + def _match_detection_batch( + predictions_classes: np.ndarray, + target_classes: np.ndarray, + iou: np.ndarray, + iou_thresholds: np.ndarray, + ) -> np.ndarray: + num_predictions, num_iou_levels = ( + predictions_classes.shape[0], + iou_thresholds.shape[0], + ) + correct = np.zeros((num_predictions, num_iou_levels), dtype=bool) + correct_class = target_classes[:, None] == predictions_classes + + for i, iou_level in enumerate(iou_thresholds): + matched_indices = np.where((iou >= iou_level) & correct_class) + + if matched_indices[0].shape[0]: + combined_indices = np.stack(matched_indices, axis=1) + iou_values = iou[matched_indices][:, None] + matches = np.hstack([combined_indices, iou_values]) + + if matched_indices[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + + correct[matches[:, 1].astype(int), i] = True + + return correct + + @staticmethod + def _compute_confusion_matrix( + sorted_matches: np.ndarray, + sorted_prediction_class_ids: np.ndarray, + unique_classes: np.ndarray, + class_counts: np.ndarray, + ) -> np.ndarray: + """ + Compute the confusion matrix for each class and IoU threshold. + + Assumes the matches and prediction_class_ids are sorted by confidence + in descending order. + + Arguments: + sorted_matches: np.ndarray, bool, shape (P, Th), that is True + if the prediction is a true positive at the given IoU threshold. + sorted_prediction_class_ids: np.ndarray, int, shape (P,), containing + the class id for each prediction. + unique_classes: np.ndarray, int, shape (C,), containing the unique + class ids. + class_counts: np.ndarray, int, shape (C,), containing the number + of true instances for each class. + + Returns: + np.ndarray, shape (C, Th, 3), containing the true positives, false + positives, and false negatives for each class and IoU threshold. 
+ """ + + num_thresholds = sorted_matches.shape[1] + num_classes = unique_classes.shape[0] + + confusion_matrix = np.zeros((num_classes, num_thresholds, 3)) + for class_idx, class_id in enumerate(unique_classes): + is_class = sorted_prediction_class_ids == class_id + num_true = class_counts[class_idx] + num_predictions = is_class.sum() + + if num_predictions == 0: + true_positives = np.zeros(num_thresholds) + false_positives = np.zeros(num_thresholds) + false_negatives = np.full(num_thresholds, num_true) + elif num_true == 0: + true_positives = np.zeros(num_thresholds) + false_positives = np.full(num_thresholds, num_predictions) + false_negatives = np.zeros(num_thresholds) + else: + true_positives = sorted_matches[is_class].sum(0) + false_positives = (1 - sorted_matches[is_class]).sum(0) + false_negatives = num_true - true_positives + confusion_matrix[class_idx] = np.stack( + [true_positives, false_positives, false_negatives], axis=1 + ) + + return confusion_matrix + + @staticmethod + def _compute_recall(confusion_matrix: np.ndarray) -> np.ndarray: + """ + Broadcastable function, computing the recall from the confusion matrix. + + Arguments: + confusion_matrix: np.ndarray, shape (N, ..., 3), where the last dimension + contains the true positives, false positives, and false negatives. + + Returns: + np.ndarray, shape (N, ...), containing the recall for each element. + """ + if not confusion_matrix.shape[-1] == 3: + raise ValueError( + f"Confusion matrix must have shape (..., 3), got " + f"{confusion_matrix.shape}" + ) + true_positives = confusion_matrix[..., 0] + false_negatives = confusion_matrix[..., 2] + + denominator = true_positives + false_negatives + recall = np.where(denominator == 0, 0, true_positives / denominator) + + return recall + + def _detections_content(self, detections: Detections) -> np.ndarray: + """Return boxes, masks or oriented bounding boxes from detections.""" + if self._metric_target == MetricTarget.BOXES: + return detections.xyxy + if self._metric_target == MetricTarget.MASKS: + return ( + detections.mask + if detections.mask is not None + else np.empty((0, 0, 0), dtype=bool) + ) + if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: + if obb := detections.data.get(ORIENTED_BOX_COORDINATES): + return np.ndarray(obb, dtype=np.float32) + return np.empty((0, 8), dtype=np.float32) + raise ValueError(f"Invalid metric target: {self._metric_target}") + + def _filter_detections_by_size( + self, detections: Detections, size_category: ObjectSizeCategory + ) -> Detections: + """Return a copy of detections with contents filtered by object size.""" + new_detections = deepcopy(detections) + if detections.is_empty() or size_category == ObjectSizeCategory.ANY: + return new_detections + + sizes = get_detection_size_category(new_detections, self._metric_target) + size_mask = sizes == size_category.value + + new_detections.xyxy = new_detections.xyxy[size_mask] + if new_detections.mask is not None: + new_detections.mask = new_detections.mask[size_mask] + if new_detections.class_id is not None: + new_detections.class_id = new_detections.class_id[size_mask] + if new_detections.confidence is not None: + new_detections.confidence = new_detections.confidence[size_mask] + if new_detections.tracker_id is not None: + new_detections.tracker_id = new_detections.tracker_id[size_mask] + if new_detections.data is not None: + for key, value in new_detections.data.items(): + new_detections.data[key] = np.array(value)[size_mask] + + return new_detections + + def 
_filter_predictions_and_targets_by_size( + self, + predictions_list: List[Detections], + targets_list: List[Detections], + size_category: ObjectSizeCategory, + ) -> Tuple[List[Detections], List[Detections]]: + """ + Filter predictions and targets by object size category. + """ + new_predictions_list = [] + new_targets_list = [] + for predictions, targets in zip(predictions_list, targets_list): + new_predictions_list.append( + self._filter_detections_by_size(predictions, size_category) + ) + new_targets_list.append( + self._filter_detections_by_size(targets, size_category) + ) + return new_predictions_list, new_targets_list + + +@dataclass +class RecallResult: + """ + The results of the recall metric calculation. + + Defaults to `0` if no detections or targets were provided. + Provides a custom `__str__` method for pretty printing. + + Attributes: + metric_target (MetricTarget): the type of data used for the metric - + boxes, masks or oriented bounding boxes. + averaging_method (AveragingMethod): the averaging method used to compute the + recall. Determines how the recall is aggregated across classes. + recall_at_50 (float): the recall at IoU threshold of `0.5`. + recall_at_75 (float): the recall at IoU threshold of `0.75`. + recall_scores (np.ndarray): the recall scores at each IoU threshold. + Shape: `(num_iou_thresholds,)` + recall_per_class (np.ndarray): the recall scores per class and IoU threshold. + Shape: `(num_target_classes, num_iou_thresholds)` + iou_thresholds (np.ndarray): the IoU thresholds used in the calculations. + matched_classes (np.ndarray): the class IDs of all matched classes. + Corresponds to the rows of `recall_per_class`. + small_objects (Optional[RecallResult]): the Recall metric results + for small objects. + medium_objects (Optional[RecallResult]): the Recall metric results + for medium objects. + large_objects (Optional[RecallResult]): the Recall metric results + for large objects. + """ + + metric_target: MetricTarget + averaging_method: AveragingMethod + + @property + def recall_at_50(self) -> float: + return self.recall_scores[0] + + @property + def recall_at_75(self) -> float: + return self.recall_scores[5] + + recall_scores: np.ndarray + recall_per_class: np.ndarray + iou_thresholds: np.ndarray + matched_classes: np.ndarray + + small_objects: Optional[RecallResult] + medium_objects: Optional[RecallResult] + large_objects: Optional[RecallResult] + + def __str__(self) -> str: + """ + Format as a pretty string. 
+ + Example: + ```python + print(recall_result) + ``` + """ + out_str = ( + f"{self.__class__.__name__}:\n" + f"Metric target: {self.metric_target}\n" + f"Averaging method: {self.averaging_method}\n" + f"R @ 50: {self.recall_at_50:.4f}\n" + f"R @ 75: {self.recall_at_75:.4f}\n" + f"R @ thresh: {self.recall_scores}\n" + f"IoU thresh: {self.iou_thresholds}\n" + f"Recall per class:\n" + ) + if self.recall_per_class.size == 0: + out_str += " No results\n" + for class_id, recall_of_class in zip( + self.matched_classes, self.recall_per_class + ): + out_str += f" {class_id}: {recall_of_class}\n" + + indent = " " + if self.small_objects is not None: + indented = indent + str(self.small_objects).replace("\n", f"\n{indent}") + out_str += f"\nSmall objects:\n{indented}" + if self.medium_objects is not None: + indented = indent + str(self.medium_objects).replace("\n", f"\n{indent}") + out_str += f"\nMedium objects:\n{indented}" + if self.large_objects is not None: + indented = indent + str(self.large_objects).replace("\n", f"\n{indent}") + out_str += f"\nLarge objects:\n{indented}" + + return out_str + + def to_pandas(self) -> "pd.DataFrame": + """ + Convert the result to a pandas DataFrame. + + Returns: + (pd.DataFrame): The result as a DataFrame. + """ + ensure_pandas_installed() + import pandas as pd + + pandas_data = { + "R@50": self.recall_at_50, + "R@75": self.recall_at_75, + } + + if self.small_objects is not None: + small_objects_df = self.small_objects.to_pandas() + for key, value in small_objects_df.items(): + pandas_data[f"small_objects_{key}"] = value + if self.medium_objects is not None: + medium_objects_df = self.medium_objects.to_pandas() + for key, value in medium_objects_df.items(): + pandas_data[f"medium_objects_{key}"] = value + if self.large_objects is not None: + large_objects_df = self.large_objects.to_pandas() + for key, value in large_objects_df.items(): + pandas_data[f"large_objects_{key}"] = value + + return pd.DataFrame(pandas_data, index=[0]) + + def plot(self): + """ + Plot the recall results. 
+ """ + + labels = ["Recall@50", "Recall@75"] + values = [self.recall_at_50, self.recall_at_75] + colors = [LEGACY_COLOR_PALETTE[0]] * 2 + + if self.small_objects is not None: + small_objects = self.small_objects + labels += ["Small: R@50", "Small: R@75"] + values += [small_objects.recall_at_50, small_objects.recall_at_75] + colors += [LEGACY_COLOR_PALETTE[3]] * 2 + + if self.medium_objects is not None: + medium_objects = self.medium_objects + labels += ["Medium: R@50", "Medium: R@75"] + values += [medium_objects.recall_at_50, medium_objects.recall_at_75] + colors += [LEGACY_COLOR_PALETTE[2]] * 2 + + if self.large_objects is not None: + large_objects = self.large_objects + labels += ["Large: R@50", "Large: R@75"] + values += [large_objects.recall_at_50, large_objects.recall_at_75] + colors += [LEGACY_COLOR_PALETTE[4]] * 2 + + plt.rcParams["font.family"] = "monospace" + + _, ax = plt.subplots(figsize=(10, 6)) + ax.set_ylim(0, 1) + ax.set_ylabel("Value", fontweight="bold") + title = ( + f"Recall, by Object Size" + f"\n(target: {self.metric_target.value}," + f" averaging: {self.averaging_method.value})" + ) + ax.set_title(title, fontweight="bold") + + x_positions = range(len(labels)) + bars = ax.bar(x_positions, values, color=colors, align="center") + + ax.set_xticks(x_positions) + ax.set_xticklabels(labels, rotation=45, ha="right") + + for bar in bars: + y_value = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2, + y_value + 0.02, + f"{y_value:.2f}", + ha="center", + va="bottom", + ) + + plt.rcParams["font.family"] = "sans-serif" + + plt.tight_layout() + plt.show() From fbd96d10c4c6358ddba4bf2bd4d34f498ebfde2a Mon Sep 17 00:00:00 2001 From: LinasKo Date: Fri, 18 Oct 2024 15:15:40 +0300 Subject: [PATCH 2/3] Add new and missing docstrings, examples to metrics, new Common section --- docs/metrics/common_values.md | 20 +++++++ docs/metrics/precision.md | 2 +- docs/metrics/recall.md | 2 +- mkdocs.yml | 1 + supervision/metrics/core.py | 26 +++++---- supervision/metrics/f1_score.py | 55 +++++++++++++++++- supervision/metrics/mean_average_precision.py | 44 ++++++++------ supervision/metrics/precision.py | 58 ++++++++++++++++++- supervision/metrics/recall.py | 58 ++++++++++++++++++- 9 files changed, 231 insertions(+), 35 deletions(-) create mode 100644 docs/metrics/common_values.md diff --git a/docs/metrics/common_values.md b/docs/metrics/common_values.md new file mode 100644 index 000000000..b7600f3f1 --- /dev/null +++ b/docs/metrics/common_values.md @@ -0,0 +1,20 @@ +--- +comments: true +status: new +--- + +# Common Values + +This page contains supplementary values, types and enums that metrics use. + + + +:::supervision.metrics.core.MetricTarget + + + +:::supervision.metrics.core.AveragingMethod diff --git a/docs/metrics/precision.md b/docs/metrics/precision.md index ca318f8fb..c704452ee 100644 --- a/docs/metrics/precision.md +++ b/docs/metrics/precision.md @@ -3,7 +3,7 @@ comments: true status: new --- -# F1 Score +# Precision
 
 <div class="md-typeset">
     <h2>Precision</h2>
diff --git a/docs/metrics/recall.md b/docs/metrics/recall.md index 5baa4d3ee..78dde8334 100644 --- a/docs/metrics/recall.md +++ b/docs/metrics/recall.md @@ -3,7 +3,7 @@ comments: true status: new --- -# F1 Score +# Recall
 
 <div class="md-typeset">
     <h2>Recall</h2>
diff --git a/mkdocs.yml b/mkdocs.yml index a3c9c1caa..b30dbcfcc 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -69,6 +69,7 @@ nav: - Precision: metrics/precision.md - Recall: metrics/recall.md - F1 Score: metrics/f1_score.md + - Common Values: metrics/common_values.md - Legacy Metrics: detection/metrics.md - Utils: - Video: utils/video.md diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py index d1818441e..def5999a0 100644 --- a/supervision/metrics/core.py +++ b/supervision/metrics/core.py @@ -37,9 +37,10 @@ class MetricTarget(Enum): """ Specifies what type of detection is used to compute the metric. - * BOXES: xyxy bounding boxes - * MASKS: Binary masks - * ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB) + Attributes: + BOXES: xyxy bounding boxes + MASKS: Binary masks + ORIENTED_BOUNDING_BOXES: Oriented bounding boxes (OBB) """ BOXES = "boxes" @@ -54,15 +55,16 @@ class AveragingMethod(Enum): Suppose, before returning the final result, a metric is computed for each class. How do you combine those to get the final number? - * MACRO: Calculate the metric for each class and average the results. The simplest - averaging method, but it does not take class imbalance into account. - * MICRO: Calculate the metric globally by counting the total true positives, false - positives, and false negatives. Micro averaging is useful when you want to give - more importance to classes with more samples. It's also more appropriate if you - have an imbalance in the number of instances per class. - * WEIGHTED: Calculate the metric for each class and average the results, weighted by - the number of true instances of each class. Use weighted averaging if you want - to take class imbalance into account. + Attributes: + MACRO: Calculate the metric for each class and average the results. The simplest + averaging method, but it does not take class imbalance into account. + MICRO: Calculate the metric globally by counting the total true positives, false + positives, and false negatives. Micro averaging is useful when you want to + give more importance to classes with more samples. It's also more + appropriate if you have an imbalance in the number of instances per class. + WEIGHTED: Calculate the metric for each class and average the results, weighted + by the number of true instances of each class. Use weighted averaging if + you want to take class imbalance into account. """ MACRO = "macro" diff --git a/supervision/metrics/f1_score.py b/supervision/metrics/f1_score.py index 2ca5bca5c..ba4fcd59a 100644 --- a/supervision/metrics/f1_score.py +++ b/supervision/metrics/f1_score.py @@ -23,11 +23,45 @@ class F1Score(Metric): + """ + F1 Score is a metric used to evaluate object detection models. It is the harmonic + mean of precision and recall, calculated at different IoU thresholds. + + In simple terms, F1 Score is a measure of a model's balance between precision and + recall (accuracy and completeness), calculated as: + + `F1 = 2 * (precision * recall) / (precision + recall)` + + Example: + ```python + import supervision as sv + from supervision.metrics import F1Score + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + f1_metric = F1Score() + f1_result = f1_metric.update(predictions, targets).compute() + + print(f1_result) + print(f1_result.f1_50) + print(f1_result.small_objects.f1_50) + ``` + """ + def __init__( self, metric_target: MetricTarget = MetricTarget.BOXES, averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, ): + """ + Initialize the F1Score metric. 
+ + Args: + metric_target (MetricTarget): The type of detection data to use. + averaging_method (AveragingMethod): The averaging method used to compute the + F1 scores. Determines how the F1 scores are aggregated across classes. + """ self._metric_target = metric_target if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: raise NotImplementedError( @@ -40,6 +74,9 @@ def __init__( self._targets_list: List[Detections] = [] def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ self._predictions_list = [] self._targets_list = [] @@ -48,6 +85,16 @@ def update( predictions: Union[Detections, List[Detections]], targets: Union[Detections, List[Detections]], ) -> F1Score: + """ + Add new predictions and targets to the metric, but do not compute the result. + + Args: + predictions (Union[Detections, List[Detections]]): The predicted detections. + targets (Union[Detections, List[Detections]]): The target detections. + + Returns: + (F1Score): The updated metric instance. + """ if not isinstance(predictions, list): predictions = [predictions] if not isinstance(targets, list): @@ -65,6 +112,13 @@ def update( return self def compute(self) -> F1ScoreResult: + """ + Calculate the F1 score metric based on the stored predictions and ground-truth + data, at different IoU thresholds. + + Returns: + (F1ScoreResult): The F1 score metric result. + """ result = self._compute(self._predictions_list, self._targets_list) small_predictions, small_targets = self._filter_predictions_and_targets_by_size( @@ -373,7 +427,6 @@ class F1ScoreResult: The results of the F1 score metric calculation. Defaults to `0` if no detections or targets were provided. - Provides a custom `__str__` method for pretty printing. Attributes: metric_target (MetricTarget): the type of data used for the metric - diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index dbd60b2e7..8cec50c85 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -23,6 +23,27 @@ class MeanAveragePrecision(Metric): + """ + Mean Average Precision (mAP) is a metric used to evaluate object detection models. + It is the average of the precision-recall curves at different IoU thresholds. + + Example: + ```python + import supervision as sv + from supervision.metrics import MeanAveragePrecision + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + map_metric = MeanAveragePrecision() + map_result = map_metric.update(predictions, targets).compute() + + print(map_result) + print(map_result.map50_95) + map_result.plot() + ``` + """ + def __init__( self, metric_target: MetricTarget = MetricTarget.BOXES, @@ -47,6 +68,9 @@ def __init__( self._targets_list: List[Detections] = [] def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ self._predictions_list = [] self._targets_list = [] @@ -95,26 +119,10 @@ def compute( ) -> MeanAveragePrecisionResult: """ Calculate Mean Average Precision based on predicted and ground-truth - detections at different thresholds. + detections at different thresholds. Returns: - (MeanAveragePrecisionResult): New instance of MeanAveragePrecision. - - Example: - ```python - import supervision as sv - from supervision.metrics import MeanAveragePrecision - - predictions = sv.Detections(...) - targets = sv.Detections(...) 
- - map_metric = MeanAveragePrecision() - map_result = map_metric.update(predictions, targets).compute() - - print(map_result) - print(map_result.map50_95) - map_result.plot() - ``` + (MeanAveragePrecisionResult): The Mean Average Precision result. """ result = self._compute(self._predictions_list, self._targets_list) diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py index ba441831a..d915e1f49 100644 --- a/supervision/metrics/precision.py +++ b/supervision/metrics/precision.py @@ -23,11 +23,48 @@ class Precision(Metric): + """ + Precision is a metric used to evaluate object detection models. It is the ratio of + true positive detections to the total number of predicted detections. We calculate + it at different IoU thresholds. + + In simple terms, Precision is a measure of a model's accuracy, calculated as: + + `Precision = TP / (TP + FP)` + + Here, `TP` is the number of true positives (correct detections), and `FP` is the + number of false positive detections (detected, but incorrectly). + + Example: + ```python + import supervision as sv + from supervision.metrics import Precision + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + precision_metric = Precision() + precision_result = precision_metric.update(predictions, targets).compute() + + print(precision_result) + print(precision_result.precision_at_50) + print(precision_result.small_objects.precision_at_50) + ``` + """ + def __init__( self, metric_target: MetricTarget = MetricTarget.BOXES, averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, ): + """ + Initialize the Precision metric. + + Args: + metric_target (MetricTarget): The type of detection data to use. + averaging_method (AveragingMethod): The averaging method used to compute the + precision. Determines how the precision is aggregated across classes. + """ self._metric_target = metric_target if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: raise NotImplementedError( @@ -40,6 +77,9 @@ def __init__( self._targets_list: List[Detections] = [] def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ self._predictions_list = [] self._targets_list = [] @@ -48,6 +88,16 @@ def update( predictions: Union[Detections, List[Detections]], targets: Union[Detections, List[Detections]], ) -> Precision: + """ + Add new predictions and targets to the metric, but do not compute the result. + + Args: + predictions (Union[Detections, List[Detections]]): The predicted detections. + targets (Union[Detections, List[Detections]]): The target detections. + + Returns: + (Precision): The updated metric instance. + """ if not isinstance(predictions, list): predictions = [predictions] if not isinstance(targets, list): @@ -65,6 +115,13 @@ def update( return self def compute(self) -> PrecisionResult: + """ + Calculate the precision metric based on the stored predictions and ground-truth + data, at different IoU thresholds. + + Returns: + (PrecisionResult): The precision metric result. + """ result = self._compute(self._predictions_list, self._targets_list) small_predictions, small_targets = self._filter_predictions_and_targets_by_size( @@ -373,7 +430,6 @@ class PrecisionResult: The results of the precision metric calculation. Defaults to `0` if no detections or targets were provided. - Provides a custom `__str__` method for pretty printing. 
Attributes: metric_target (MetricTarget): the type of data used for the metric - diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py index 7c90859cc..9eae24f8e 100644 --- a/supervision/metrics/recall.py +++ b/supervision/metrics/recall.py @@ -23,11 +23,48 @@ class Recall(Metric): + """ + Recall is a metric used to evaluate object detection models. It is the ratio of + true positive detections to the total number of ground truth instances. We calculate + it at different IoU thresholds. + + In simple terms, Recall is a measure of a model's completeness, calculated as: + + `Recall = TP / (TP + FN)` + + Here, `TP` is the number of true positives (correct detections), and `FN` is the + number of false negatives (missed detections). + + Example: + ```python + import supervision as sv + from supervision.metrics import Recall + + predictions = sv.Detections(...) + targets = sv.Detections(...) + + recall_metric = Recall() + recall_result = recall_metric.update(predictions, targets).compute() + + print(recall_result) + print(recall_result.recall_at_50) + print(recall_result.small_objects.recall_at_50) + ``` + """ + def __init__( self, metric_target: MetricTarget = MetricTarget.BOXES, averaging_method: AveragingMethod = AveragingMethod.WEIGHTED, ): + """ + Initialize the Recall metric. + + Args: + metric_target (MetricTarget): The type of detection data to use. + averaging_method (AveragingMethod): The averaging method used to compute the + recall. Determines how the recall is aggregated across classes. + """ self._metric_target = metric_target if self._metric_target == MetricTarget.ORIENTED_BOUNDING_BOXES: raise NotImplementedError( @@ -40,6 +77,9 @@ def __init__( self._targets_list: List[Detections] = [] def reset(self) -> None: + """ + Reset the metric to its initial state, clearing all stored data. + """ self._predictions_list = [] self._targets_list = [] @@ -48,6 +88,16 @@ def update( predictions: Union[Detections, List[Detections]], targets: Union[Detections, List[Detections]], ) -> Recall: + """ + Add new predictions and targets to the metric, but do not compute the result. + + Args: + predictions (Union[Detections, List[Detections]]): The predicted detections. + targets (Union[Detections, List[Detections]]): The target detections. + + Returns: + (Recall): The updated metric instance. + """ if not isinstance(predictions, list): predictions = [predictions] if not isinstance(targets, list): @@ -65,6 +115,13 @@ def update( return self def compute(self) -> RecallResult: + """ + Calculate the precision metric based on the stored predictions and ground-truth + data, at different IoU thresholds. + + Returns: + (RecallResult): The precision metric result. + """ result = self._compute(self._predictions_list, self._targets_list) small_predictions, small_targets = self._filter_predictions_and_targets_by_size( @@ -371,7 +428,6 @@ class RecallResult: The results of the recall metric calculation. Defaults to `0` if no detections or targets were provided. - Provides a custom `__str__` method for pretty printing. 
Attributes: metric_target (MetricTarget): the type of data used for the metric - From 3e8a88a8d4ba3d31850788af839e469410a86961 Mon Sep 17 00:00:00 2001 From: LinasKo Date: Fri, 18 Oct 2024 15:28:45 +0300 Subject: [PATCH 3/3] Add Precision and Recall to metrics __init__ --- supervision/metrics/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/supervision/metrics/__init__.py b/supervision/metrics/__init__.py index 8ae33e639..90fc17b47 100644 --- a/supervision/metrics/__init__.py +++ b/supervision/metrics/__init__.py @@ -8,6 +8,8 @@ MeanAveragePrecision, MeanAveragePrecisionResult, ) +from supervision.metrics.precision import Precision, PrecisionResult +from supervision.metrics.recall import Recall, RecallResult from supervision.metrics.utils.object_size import ( ObjectSizeCategory, get_detection_size_category,
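
For quick reference, the sketch below exercises the public API these patches add: `Precision` and `Recall` exported from `supervision.metrics`, together with `MetricTarget` and `AveragingMethod` from `supervision.metrics.core`. The box coordinates, confidences, and class IDs are hypothetical placeholders, not values from the patches, and the snippet assumes a supervision build that includes these commits (plus pandas installed for `to_pandas()`).

```python
import numpy as np
import supervision as sv
from supervision.metrics import Precision, Recall
from supervision.metrics.core import AveragingMethod, MetricTarget

# Hypothetical predictions and ground truth for a single image (xyxy boxes).
predictions = sv.Detections(
    xyxy=np.array(
        [[10, 10, 50, 50], [60, 60, 120, 120], [200, 200, 260, 260]],
        dtype=np.float32,
    ),
    confidence=np.array([0.9, 0.8, 0.3], dtype=np.float32),
    class_id=np.array([0, 1, 0]),
)
targets = sv.Detections(
    xyxy=np.array([[12, 12, 48, 48], [58, 62, 118, 122]], dtype=np.float32),
    class_id=np.array([0, 1]),
)

# Metrics accumulate data via update() and are evaluated with compute().
precision_metric = Precision(
    metric_target=MetricTarget.BOXES,
    averaging_method=AveragingMethod.WEIGHTED,
)
recall_metric = Recall(metric_target=MetricTarget.BOXES)

precision_result = precision_metric.update(predictions, targets).compute()
recall_result = recall_metric.update(predictions, targets).compute()

print(precision_result)                  # pretty-printed summary with per-size breakdown
print(precision_result.precision_at_50)  # precision at IoU threshold 0.5
print(recall_result.recall_at_50)        # recall at IoU threshold 0.5
print(recall_result.to_pandas())         # R@50 / R@75 columns (requires pandas)
```

Both metrics also accept `MetricTarget.MASKS`; `MetricTarget.ORIENTED_BOUNDING_BOXES` raises `NotImplementedError` in both constructors.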