diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst index 0e9d75fd..f1a36737 100644 --- a/docs/en/api/metrics.rst +++ b/docs/en/api/metrics.rst @@ -31,6 +31,7 @@ Metrics OIDMeanAP F1Score HmeanIoU + InstanceSeg EndPointError PCKAccuracy MpiiPCKAccuracy diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst index 0e9d75fd..f1a36737 100644 --- a/docs/zh_cn/api/metrics.rst +++ b/docs/zh_cn/api/metrics.rst @@ -31,6 +31,7 @@ Metrics OIDMeanAP F1Score HmeanIoU + InstanceSeg EndPointError PCKAccuracy MpiiPCKAccuracy diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py index 171e7196..657b8e9b 100644 --- a/mmeval/metrics/__init__.py +++ b/mmeval/metrics/__init__.py @@ -12,6 +12,7 @@ from .f1_score import F1Score from .gradient_error import GradientError from .hmean_iou import HmeanIoU +from .instance_seg import InstanceSeg from .mae import MeanAbsoluteError from .matting_mse import MattingMeanSquaredError from .mean_iou import MeanIoU @@ -34,7 +35,7 @@ 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall', 'PeakSignalNoiseRatio', 'MeanAbsoluteError', 'MeanSquaredError', 'StructuralSimilarity', 'SignalNoiseRatio', 'MultiLabelMetric', - 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP', + 'AveragePrecision', 'AVAMeanAP', 'BLEU', 'InstanceSeg', 'DOTAMeanAP', 'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError', 'ConnectivityError', 'ROUGE' ] diff --git a/mmeval/metrics/_vendor/scannet/README.md b/mmeval/metrics/_vendor/scannet/README.md new file mode 100644 index 00000000..46134b19 --- /dev/null +++ b/mmeval/metrics/_vendor/scannet/README.md @@ -0,0 +1,2 @@ +The code under this folder is from the official [ScanNet repo](https://github.com/ScanNet/ScanNet). +Some unused codes are removed to minimize the length of codes added. diff --git a/mmeval/metrics/_vendor/scannet/__init__.py b/mmeval/metrics/_vendor/scannet/__init__.py new file mode 100644 index 00000000..812196a3 --- /dev/null +++ b/mmeval/metrics/_vendor/scannet/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .evaluate_semantic_instance import scannet_eval + +__all__ = ['scannet_eval'] diff --git a/mmeval/metrics/_vendor/scannet/evaluate_semantic_instance.py b/mmeval/metrics/_vendor/scannet/evaluate_semantic_instance.py new file mode 100644 index 00000000..397f863d --- /dev/null +++ b/mmeval/metrics/_vendor/scannet/evaluate_semantic_instance.py @@ -0,0 +1,346 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py # noqa +import numpy as np +from copy import deepcopy + +from . import util_3d + + +def evaluate_matches(matches, class_labels, options): + """Evaluate instance segmentation from matched gt and predicted instances + for all scenes. + + Args: + matches (dict): Contains gt2pred and pred2gt infos for every scene. + class_labels (tuple[str]): Class names. + options (dict): ScanNet evaluator options. See get_options. + + Returns: + np.array: Average precision scores for all thresholds and categories. 
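+            The array has shape ``(num_distance_thresholds, num_classes,
+            num_overlaps)``.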
+ """ + overlaps = options['overlaps'] + min_region_sizes = [options['min_region_sizes'][0]] + dist_threshes = [options['distance_threshes'][0]] + dist_confs = [options['distance_confs'][0]] + + # results: class x overlap + ap = np.zeros((len(dist_threshes), len(class_labels), len(overlaps)), + float) + for di, (min_region_size, distance_thresh, distance_conf) in enumerate( + zip(min_region_sizes, dist_threshes, dist_confs)): + for oi, overlap_th in enumerate(overlaps): + pred_visited = {} + for m in matches: + for label_name in class_labels: + for p in matches[m]['pred'][label_name]: + if 'filename' in p: + pred_visited[p['filename']] = False + for li, label_name in enumerate(class_labels): + y_true = np.empty(0) + y_score = np.empty(0) + hard_false_negatives = 0 + has_gt = False + has_pred = False + for m in matches: + pred_instances = matches[m]['pred'][label_name] + gt_instances = matches[m]['gt'][label_name] + # filter groups in ground truth + gt_instances = [ + gt for gt in gt_instances + if gt['instance_id'] >= 1000 and gt['vert_count'] >= + min_region_size and gt['med_dist'] <= distance_thresh + and gt['dist_conf'] >= distance_conf + ] + if gt_instances: + has_gt = True + if pred_instances: + has_pred = True + + cur_true = np.ones(len(gt_instances)) + cur_score = np.ones(len(gt_instances)) * (-float('inf')) + cur_match = np.zeros(len(gt_instances), dtype=bool) + # collect matches + for (gti, gt) in enumerate(gt_instances): + found_match = False + for pred in gt['matched_pred']: + # greedy assignments + if pred_visited[pred['filename']]: + continue + overlap = float(pred['intersection']) / ( + gt['vert_count'] + pred['vert_count'] - + pred['intersection']) + if overlap > overlap_th: + confidence = pred['confidence'] + # if already have a prediction for this gt, + # the prediction with the lower score is automatically a false positive # noqa + if cur_match[gti]: + max_score = max(cur_score[gti], confidence) + min_score = min(cur_score[gti], confidence) + cur_score[gti] = max_score + # append false positive + cur_true = np.append(cur_true, 0) + cur_score = np.append(cur_score, min_score) + cur_match = np.append(cur_match, True) + # otherwise set score + else: + found_match = True + cur_match[gti] = True + cur_score[gti] = confidence + pred_visited[pred['filename']] = True + if not found_match: + hard_false_negatives += 1 + # remove non-matched ground truth instances + cur_true = cur_true[cur_match] + cur_score = cur_score[cur_match] + + # collect non-matched predictions as false positive + for pred in pred_instances: + found_gt = False + for gt in pred['matched_gt']: + overlap = float(gt['intersection']) / ( + gt['vert_count'] + pred['vert_count'] - + gt['intersection']) + if overlap > overlap_th: + found_gt = True + break + if not found_gt: + num_ignore = pred['void_intersection'] + for gt in pred['matched_gt']: + # group? 
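+                                # ids below 1000 are the "group" annotations
+                                # that were filtered out of the gt list above;
+                                # overlap with them (and with too-small or
+                                # too-distant gt) only enlarges the ignored
+                                # region instead of counting the prediction
+                                # as a false positive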
+ if gt['instance_id'] < 1000: + num_ignore += gt['intersection'] + # small ground truth instances + if gt['vert_count'] < min_region_size or gt[ + 'med_dist'] > distance_thresh or gt[ + 'dist_conf'] < distance_conf: + num_ignore += gt['intersection'] + proportion_ignore = float( + num_ignore) / pred['vert_count'] + # if not ignored append false positive + if proportion_ignore <= overlap_th: + cur_true = np.append(cur_true, 0) + confidence = pred['confidence'] + cur_score = np.append(cur_score, confidence) + + # append to overall results + y_true = np.append(y_true, cur_true) + y_score = np.append(y_score, cur_score) + + # compute average precision + if has_gt and has_pred: + # compute precision recall curve first + + # sorting and cumsum + score_arg_sort = np.argsort(y_score) + y_score_sorted = y_score[score_arg_sort] + y_true_sorted = y_true[score_arg_sort] + y_true_sorted_cumsum = np.cumsum(y_true_sorted) + + # unique thresholds + (thresholds, unique_indices) = np.unique( + y_score_sorted, return_index=True) + num_prec_recall = len(unique_indices) + 1 + + # prepare precision recall + num_examples = len(y_score_sorted) + # follow https://github.com/ScanNet/ScanNet/pull/26 ? # noqa + num_true_examples = y_true_sorted_cumsum[-1] if len( + y_true_sorted_cumsum) > 0 else 0 + precision = np.zeros(num_prec_recall) + recall = np.zeros(num_prec_recall) + + # deal with the first point + y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0) + # deal with remaining + for idx_res, idx_scores in enumerate(unique_indices): + cumsum = y_true_sorted_cumsum[idx_scores - 1] + tp = num_true_examples - cumsum + fp = num_examples - idx_scores - tp + fn = cumsum + hard_false_negatives + p = float(tp) / (tp + fp) + r = float(tp) / (tp + fn) + precision[idx_res] = p + recall[idx_res] = r + + # first point in curve is artificial + precision[-1] = 1. + recall[-1] = 0. + + # compute average of precision-recall curve + recall_for_conv = np.copy(recall) + recall_for_conv = np.append(recall_for_conv[0], + recall_for_conv) + recall_for_conv = np.append(recall_for_conv, 0.) + + stepWidths = np.convolve(recall_for_conv, [-0.5, 0, 0.5], + 'valid') + # integrate is now simply a dot product + ap_current = np.dot(precision, stepWidths) + + elif has_gt: + ap_current = 0.0 + else: + ap_current = float('nan') + ap[di, li, oi] = ap_current + return ap + + +def compute_averages(aps, options, class_labels): + """Averages AP scores for all categories. + + Args: + aps (np.array): AP scores for all thresholds and categories. + options (dict): ScanNet evaluator options. See get_options. + class_labels (tuple[str]): Class names. + + Returns: + dict: Overall and per-category AP scores. 
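+            Top-level keys are ``all_ap``, ``all_ap_50%`` and ``all_ap_25%``,
+            plus a ``classes`` dict that maps every label to its ``ap``,
+            ``ap50%`` and ``ap25%`` values.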
+ """ + d_inf = 0 + o50 = np.where(np.isclose(options['overlaps'], 0.5)) + o25 = np.where(np.isclose(options['overlaps'], 0.25)) + o_all_but25 = np.where( + np.logical_not(np.isclose(options['overlaps'], 0.25))) + avg_dict = {} + avg_dict['all_ap'] = np.nanmean(aps[d_inf, :, o_all_but25]) + avg_dict['all_ap_50%'] = np.nanmean(aps[d_inf, :, o50]) + avg_dict['all_ap_25%'] = np.nanmean(aps[d_inf, :, o25]) + avg_dict['classes'] = {} + for (li, label_name) in enumerate(class_labels): + avg_dict['classes'][label_name] = {} + avg_dict['classes'][label_name]['ap'] = np.average(aps[d_inf, li, + o_all_but25]) + avg_dict['classes'][label_name]['ap50%'] = np.average(aps[d_inf, li, + o50]) + avg_dict['classes'][label_name]['ap25%'] = np.average(aps[d_inf, li, + o25]) + return avg_dict + + +def assign_instances_for_scan(pred_info, gt_ids, options, valid_class_ids, + class_labels, id_to_label): + """Assign gt and predicted instances for a single scene. + + Args: + pred_info (dict): Predicted masks, labels and scores. + gt_ids (np.array): Ground truth instance masks. + options (dict): ScanNet evaluator options. See get_options. + valid_class_ids (tuple[int]): Ids of valid categories. + class_labels (tuple[str]): Class names. + id_to_label (dict[int, str]): Mapping of valid class id to class label. + + Returns: + dict: Per class assigned gt to predicted instances. + dict: Per class assigned predicted to gt instances. + """ + # get gt instances + gt_instances = util_3d.get_instances(gt_ids, valid_class_ids, class_labels, + id_to_label) + # associate + gt2pred = deepcopy(gt_instances) + for label in gt2pred: + for gt in gt2pred[label]: + gt['matched_pred'] = [] + pred2gt = {} + for label in class_labels: + pred2gt[label] = [] + num_pred_instances = 0 + # mask of void labels in the ground truth + bool_void = np.logical_not(np.in1d(gt_ids // 1000, valid_class_ids)) + # go through all prediction masks + for pred_mask_file in pred_info: + label_id = int(pred_info[pred_mask_file]['label_id']) + conf = pred_info[pred_mask_file]['conf'] + if not label_id in id_to_label: # noqa E713 + continue + label_name = id_to_label[label_id] + # read the mask + pred_mask = pred_info[pred_mask_file]['mask'] + if len(pred_mask) != len(gt_ids): + raise ValueError('len(pred_mask) != len(gt_ids)') + # convert to binary + pred_mask = np.not_equal(pred_mask, 0) + num = np.count_nonzero(pred_mask) + if num < options['min_region_sizes'][0]: + continue # skip if empty + + pred_instance = {} + pred_instance['filename'] = pred_mask_file + pred_instance['pred_id'] = num_pred_instances + pred_instance['label_id'] = label_id + pred_instance['vert_count'] = num + pred_instance['confidence'] = conf + pred_instance['void_intersection'] = np.count_nonzero( + np.logical_and(bool_void, pred_mask)) + + # matched gt instances + matched_gt = [] + # go through all gt instances with matching label + for (gt_num, gt_inst) in enumerate(gt2pred[label_name]): + intersection = np.count_nonzero( + np.logical_and(gt_ids == gt_inst['instance_id'], pred_mask)) + if intersection > 0: + gt_copy = gt_inst.copy() + pred_copy = pred_instance.copy() + gt_copy['intersection'] = intersection + pred_copy['intersection'] = intersection + matched_gt.append(gt_copy) + gt2pred[label_name][gt_num]['matched_pred'].append(pred_copy) + pred_instance['matched_gt'] = matched_gt + num_pred_instances += 1 + pred2gt[label_name].append(pred_instance) + + return gt2pred, pred2gt + + +def scannet_eval(preds, gts, options, valid_class_ids, class_labels, + id_to_label): + """Evaluate 
instance segmentation in ScanNet protocol. + + Args: + preds (list[dict]): Per scene predictions of mask, label and + confidence. + gts (list[np.array]): Per scene ground truth instance masks. + options (dict): ScanNet evaluator options. See get_options. + valid_class_ids (tuple[int]): Ids of valid categories. + class_labels (tuple[str]): Class names. + id_to_label (dict[int, str]): Mapping of valid class id to class label. + + Returns: + dict: Overall and per-category AP scores. + """ + options = get_options(options) + matches = {} + for i, (pred, gt) in enumerate(zip(preds, gts)): + matches_key = i + # assign gt to predictions + gt2pred, pred2gt = assign_instances_for_scan(pred, gt, options, + valid_class_ids, + class_labels, id_to_label) + matches[matches_key] = {} + matches[matches_key]['gt'] = gt2pred + matches[matches_key]['pred'] = pred2gt + + ap_scores = evaluate_matches(matches, class_labels, options) + avgs = compute_averages(ap_scores, options, class_labels) + return avgs + + +def get_options(options=None): + """Set ScanNet evaluator options. + + Args: + options (dict, optional): Not default options. Default: None. + + Returns: + dict: Updated options with all 4 keys. + """ + assert options is None or isinstance(options, dict) + _options = dict( + overlaps=np.append(np.arange(0.5, 0.95, 0.05), 0.25), + min_region_sizes=np.array([100]), + distance_threshes=np.array([float('inf')]), + distance_confs=np.array([-float('inf')])) + if options is not None: + _options.update(options) + return _options diff --git a/mmeval/metrics/_vendor/scannet/util_3d.py b/mmeval/metrics/_vendor/scannet/util_3d.py new file mode 100644 index 00000000..fd9291de --- /dev/null +++ b/mmeval/metrics/_vendor/scannet/util_3d.py @@ -0,0 +1,83 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/util_3d.py # noqa +import json +import numpy as np + + +class Instance: + """Single instance for ScanNet evaluator. + + Args: + mesh_vert_instances (np.array): Instance ids for each point. + instance_id: Id of single instance. + """ + instance_id = 0 + label_id = 0 + vert_count = 0 + med_dist = -1 + dist_conf = 0.0 + + def __init__(self, mesh_vert_instances, instance_id): + if instance_id == -1: + return + self.instance_id = int(instance_id) + self.label_id = int(self.get_label_id(instance_id)) + self.vert_count = int( + self.get_instance_verts(mesh_vert_instances, instance_id)) + + @staticmethod + def get_label_id(instance_id): + return int(instance_id // 1000) + + @staticmethod + def get_instance_verts(mesh_vert_instances, instance_id): + return (mesh_vert_instances == instance_id).sum() + + def to_json(self): + return json.dumps( + self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + def to_dict(self): + dict = {} + dict['instance_id'] = self.instance_id + dict['label_id'] = self.label_id + dict['vert_count'] = self.vert_count + dict['med_dist'] = self.med_dist + dict['dist_conf'] = self.dist_conf + return dict + + def from_json(self, data): + self.instance_id = int(data['instance_id']) + self.label_id = int(data['label_id']) + self.vert_count = int(data['vert_count']) + if 'med_dist' in data: + self.med_dist = float(data['med_dist']) + self.dist_conf = float(data['dist_conf']) + + def __str__(self): + return '(' + str(self.instance_id) + ')' + + +def get_instances(ids, class_ids, class_labels, id2label): + """Transform gt instance mask to Instance objects. + + Args: + ids (np.array): Instance ids for each point. 
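+            Each id encodes its semantic class as ``id // 1000``; points
+            with id 0 are skipped.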
+ class_ids: (tuple[int]): Ids of valid categories. + class_labels (tuple[str]): Class names. + id2label: (dict[int, str]): Mapping of valid class id to class label. + + Returns: + dict [str, list]: Instance objects grouped by class label. + """ + instances = {} + for label in class_labels: + instances[label] = [] + instance_ids = np.unique(ids) + for id in instance_ids: + if id == 0: + continue + inst = Instance(ids, id) + if inst.label_id in class_ids: + instances[id2label[inst.label_id]].append(inst.to_dict()) + return instances diff --git a/mmeval/metrics/instance_seg.py b/mmeval/metrics/instance_seg.py new file mode 100644 index 00000000..4e591601 --- /dev/null +++ b/mmeval/metrics/instance_seg.py @@ -0,0 +1,192 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Dict, List, Optional, Sequence + +from mmeval.core.base_metric import BaseMetric +from mmeval.metrics._vendor.scannet import scannet_eval + + +class InstanceSeg(BaseMetric): + """3D instance segmentation evaluation metric. + + This metric is for ScanNet 3D instance segmentation tasks. For more info + about ScanNet, please read [here](https://github.com/ScanNet/ScanNet). + + Args: + dataset_meta (dict, optional): Provide dataset meta information. + classes (List[str], optional): Provide dataset classes information as + an alternative to dataset_meta. + valid_class_ids (List[int], optional): Provide dataset valid class ids + information as an alternative to dataset_meta. + **kwargs: Keyword parameters passed to :class:`BaseMetric`. + + Example: + >>> import numpy as np + >>> from mmeval import InstanceSegMetric + >>> seg_valid_class_ids = (3, 4, 5) + >>> class_labels = ('cabinet', 'bed', 'chair') + >>> dataset_meta = dict( + ... seg_valid_class_ids=seg_valid_class_ids, classes=class_labels) + >>> + >>> def _demo_mm_model_output(): + >>> n_points_list = [3300, 3000] + >>> gt_labels_list = [[0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 0], + >>> [1, 1, 2, 1, 2, 2, 0, 0, 0, 0, 1]] + >>> predictions = [] + >>> groundtruths = [] + >>> + >>> for idx, points_num in enumerate(n_points_list): + >>> points = np.ones(points_num) * -1 + >>> gt = np.ones(points_num) + >>> info = {} + >>> for ii, i in enumerate(gt_labels_list[idx]): + >>> i = seg_valid_class_ids[i] + >>> points[ii * 300:(ii + 1) * 300] = ii + >>> gt[ii * 300:(ii + 1) * 300] = i * 1000 + ii + >>> info[f"{idx}_{ii}"] = { + >>> 'mask': (points == ii), + >>> 'label_id': i, + >>> 'conf': 0.99 + >>> } + >>> predictions.append(info) + >>> groundtruths.append(gt) + >>> + >>> return predictions, groundtruths + >>> + >>> instance_seg_metric = InstanceSegMetric(dataset_meta=dataset_meta) + >>> res = instance_seg_metric(predictions, groundtruths) + >>> res + { + 'all_ap': 1.0, + 'all_ap_50%': 1.0, + 'all_ap_25%': 1.0, + 'classes': { + 'cabinet': { + 'ap': 1.0, + 'ap50%': 1.0, + 'ap25%': 1.0 + }, + 'bed': { + 'ap': 0.5, + 'ap50%': 0.5, + 'ap25%': 0.5 + }, + 'chair': { + 'ap': 1.0, + 'ap50%': 1.0, + 'ap25%': 1.0 + } + } + } + """ + + def __init__(self, + classes: Optional[List[str]] = None, + valid_class_ids: Optional[List[int]] = None, + **kwargs): + super().__init__(**kwargs) + self._valid_class_ids = valid_class_ids + self._classes = classes + + @property + def classes(self): + """Returns classes. + + The classes should be set during initialization, otherwise it will + be obtained from the 'classes' field in ``self.dataset_meta``. + + Raises: + RuntimeError: If the classes is not set. + + Returns: + List[str]: The classes. 
+ """ + if self._classes is not None: + return self._classes + + if self.dataset_meta and 'classes' in self.dataset_meta: + self._classes = self.dataset_meta['classes'] + else: + raise RuntimeError('The `classes` is required, and not found in ' + f'dataset_meta: {self.dataset_meta}') + return self._classes + + @property + def valid_class_ids(self): + """Returns valid class ids. + + The valid class ids should be set during initialization, otherwise + it will be obtained from the 'seg_valid_class_ids' field in + ``self.dataset_meta``. + + Raises: + RuntimeError: If valid class ids is not set. + + Returns: + List[str]: The valid class ids. + """ + if self._valid_class_ids is not None: + return self._valid_class_ids + + if self.dataset_meta and 'seg_valid_class_ids' in self.dataset_meta: + self._valid_class_ids = self.dataset_meta['seg_valid_class_ids'] + else: + raise RuntimeError( + 'The `seg_valid_class_ids` is required, and not found in ' + f'dataset_meta: {self.dataset_meta}') + return self._valid_class_ids + + def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None: # type: ignore # yapf: disable # noqa: E501 + """Process one batch of data samples and predictions. + + The processed results should be stored in ``self.results``, + which will be used to compute the metrics when all batches + have been processed. + + Args: + predictions (Sequence[Dict]): A sequence of dict. Each dict + representing a detection result. The dict has multiple keys, + each key represents the name of the instance, and the value + is also a dict, with following keys: + + - mask (array): Predicted instance masks. + - label_id (int): Predicted instance labels. + - conf (float): Predicted instance scores. + + groundtruths (Sequence[array]): A sequence of array. Each array + represents a groundtruths for an image. + """ + for prediction, groundtruth in zip(predictions, groundtruths): + self._results.append((deepcopy(prediction), deepcopy(groundtruth))) + + def compute_metric(self, results: List[List[Dict]]) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + preds = [] + gts = [] + for pred, gt in results: + preds.append(pred) + gts.append(gt) + + assert len(self.valid_class_ids) == len(self.classes) + id_to_label = { + self.valid_class_ids[i]: self.classes[i] + for i in range(len(self.valid_class_ids)) + } + + metrics = scannet_eval( + preds=preds, + gts=gts, + options=None, + valid_class_ids=self.valid_class_ids, + class_labels=self.classes, + id_to_label=id_to_label) + + return metrics diff --git a/tests/test_metrics/test_instance_seg.py b/tests/test_metrics/test_instance_seg.py new file mode 100644 index 00000000..3311cfdb --- /dev/null +++ b/tests/test_metrics/test_instance_seg.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
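+# The synthetic scenes built below follow the instance-id convention expected
+# by the vendored ScanNet evaluator: every annotated point stores
+# `semantic_class_id * 1000 + instance_index`, and each instance spans 300
+# consecutive points so it clears the default `min_region_sizes` of 100.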
+import numpy as np + +from mmeval.metrics import InstanceSeg + +seg_valid_class_ids = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, + 36, 39) +class_labels = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') +dataset_meta = dict( + seg_valid_class_ids=seg_valid_class_ids, classes=class_labels) + + +def _demo_mm_model_output(): + """Create a superset of inputs needed to run test or train batches.""" + n_points_list = [3300, 3000] + gt_labels_list = [[0, 0, 0, 0, 0, 0, 14, 14, 2, 2, 2], + [13, 13, 2, 1, 3, 3, 0, 0, 0, 0]] + predictions = [] + groundtruths = [] + + for idx, points_num in enumerate(n_points_list): + points = np.ones(points_num) * -1 + gt = np.ones(points_num) + info = {} + for ii, i in enumerate(gt_labels_list[idx]): + i = seg_valid_class_ids[i] + points[ii * 300:(ii + 1) * 300] = ii + gt[ii * 300:(ii + 1) * 300] = i * 1000 + ii + info[f'{idx}_{ii}'] = { + 'mask': (points == ii), + 'label_id': i, + 'conf': 0.99 + } + predictions.append(info) + groundtruths.append(gt) + + return predictions, groundtruths + + +def _demo_mm_model_wrong_output(): + """Create a superset of inputs needed to run test or train batches.""" + n_points_list = [3300, 3000] + gt_labels_list = [[0, 0, 0, 0, 0, 0, 14, 14, 2, 2, 2], + [13, 13, 2, 1, 3, 3, 0, 0, 0, 0]] + predictions = [] + groundtruths = [] + + for idx, points_num in enumerate(n_points_list): + points = np.ones(points_num) * -1 + gt = np.ones(points_num) + info = {} + for ii, i in enumerate(gt_labels_list[idx]): + i = seg_valid_class_ids[i] + points[ii * 300:(ii + 1) * 300] = i + gt[ii * 300:(ii + 1) * 300] = i * 1000 + ii + info[f'{idx}_{ii}'] = { + 'mask': (points == i), + 'label_id': i, + 'conf': 0.99 + } + predictions.append(info) + groundtruths.append(gt) + + return predictions, groundtruths + + +def test_evaluate(): + predictions, groundtruths = _demo_mm_model_output() + instance_seg_metric = InstanceSeg(dataset_meta=dataset_meta) + res = instance_seg_metric(predictions, groundtruths) + assert isinstance(res, dict) + for label in [ + 'cabinet', 'bed', 'chair', 'sofa', 'showercurtrain', 'toilet' + ]: + metrics = res['classes'][label] + assert metrics['ap'] == 1.0 + assert metrics['ap50%'] == 1.0 + assert metrics['ap25%'] == 1.0 + predictions, groundtruths = _demo_mm_model_wrong_output() + res = instance_seg_metric(predictions, groundtruths) + assert abs(res['classes']['cabinet']['ap50%'] - 0.12) < 0.01 + assert abs(res['classes']['cabinet']['ap25%'] - 0.4125) < 0.01 + assert abs(res['classes']['bed']['ap50%'] - 1) < 0.01 + assert abs(res['classes']['bed']['ap25%'] - 1) < 0.01 + assert abs(res['classes']['chair']['ap50%'] - 0.375) < 0.01 + assert abs(res['classes']['chair']['ap25%'] - 0.785714) < 0.01
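+
+
+def _usage_sketch():
+    """Minimal usage sketch for ``InstanceSeg`` (not collected by pytest).
+
+    Mirrors the class docstring example: one scene with three perfectly
+    predicted instances of 300 points each. Only names introduced in this
+    patch (and the module-level imports above) are used; the scene key names
+    and the confidence value are arbitrary illustration values.
+    """
+    valid_ids = (3, 4, 5)
+    labels = ('cabinet', 'bed', 'chair')
+    meta = dict(seg_valid_class_ids=valid_ids, classes=labels)
+
+    num_points = 900
+    gt = np.zeros(num_points)
+    pred_info = {}
+    for inst_idx, class_id in enumerate(valid_ids):
+        piece = slice(inst_idx * 300, (inst_idx + 1) * 300)
+        # ground truth: ScanNet-style id = class_id * 1000 + instance_index
+        gt[piece] = class_id * 1000 + inst_idx
+        # prediction: binary mask over all points, a class id and a confidence
+        mask = np.zeros(num_points, dtype=bool)
+        mask[piece] = True
+        pred_info[f'scene0_{inst_idx}'] = {
+            'mask': mask,
+            'label_id': class_id,
+            'conf': 0.9,
+        }
+
+    metric = InstanceSeg(dataset_meta=meta)
+    res = metric([pred_info], [gt])
+    # perfect predictions give an AP of 1.0 at every overlap threshold
+    assert abs(res['all_ap'] - 1.0) < 1e-6
+    assert abs(res['all_ap_50%'] - 1.0) < 1e-6
+    assert abs(res['all_ap_25%'] - 1.0) < 1e-6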