From 89a58c2468b74da95cfbf1c87927ad7285982a7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Asbj=C3=B8rn=20Munk?= <9844416+asbjrnmunk@users.noreply.github.com> Date: Fri, 4 Oct 2024 15:30:46 +0200 Subject: [PATCH] Do refactor --- .../functional/evaluation/evaluate_folder.py | 343 ++++++++++-------- yucca/pipeline/evaluation/YuccaEvaluator.py | 39 +- 2 files changed, 201 insertions(+), 181 deletions(-) diff --git a/yucca/functional/evaluation/evaluate_folder.py b/yucca/functional/evaluation/evaluate_folder.py index 7cb32ad3..6a901afb 100644 --- a/yucca/functional/evaluation/evaluate_folder.py +++ b/yucca/functional/evaluation/evaluate_folder.py @@ -13,195 +13,226 @@ from yucca.functional.evaluation.metrics import auroc -def evaluate_multilabel_folder_segm( - labels: dict, +def evaluate_folder_segm( + labels, metrics, subjects, folder_with_predictions, folder_with_ground_truth, as_binary: Optional[bool] = False, obj_metrics: Optional[bool] = False, - regions: Optional[list] = None, surface_metrics: Optional[bool] = False, surface_tol=1, + regions=None, + multilabel=False, ): - # predictions are multilabel at this point (c, h, w, d) - # ground truth MAY be converted, but may also be (h, w, d) in which case we use the label_regions - # to convert for multilabel evaluation - logging.info(f"Multilabel segmentation evaluation with regions: {regions} and labels: {labels}") - sys.stdout.flush() - resultdict = {} - meandict = {} + result_dict = {} + mean_dict = {} - labelarr = np.array(range(len(regions.keys()) + 1), dtype=np.uint8) + if multilabel: + assert regions is not None + # predictions are multilabel at this point (c, h, w, d) + # ground truth MAY be converted, but may also be (h, w, d) in which case we use the label_regions + # to convert for multilabel evaluation - for label in regions.keys(): - meandict[str(label)] = {k: [] for k in list(metrics.keys()) + obj_metrics + surface_metrics} + labels_from_regions = np.array(range(len(regions.keys()) + 1), dtype=np.uint8) + logging.info(f"Multilabel segmentation evaluation with regions: {regions} and labels: {labels}") + else: + logging.info(f"segmentation evaluation with labels: {labels}") - for case in tqdm(subjects, desc="Evaluating"): - casedict = {} - predpath = join(folder_with_predictions, case) - gtpath = join(folder_with_ground_truth, case) + for label in labels: + mean_dict[str(label)] = {k: [] for k in list(metrics.keys()) + obj_metrics + surface_metrics} + + evaluation_args = { + "folder_with_ground_truth": folder_with_ground_truth, + "folder_with_predictions": folder_with_predictions, + "labels": labels, + "as_binary": as_binary, + "obj_metrics": obj_metrics, + "surface_metrics": surface_metrics, + "surface_tol": surface_tol, + "metrics": metrics, + } - pred = nib.load(predpath) - spacing = get_nib_spacing(pred)[:3] - pred = pred.get_fdata().astype(np.uint8) - pred = pred.transpose([3, 0, 1, 2]) - gt = nib.load(gtpath).get_fdata() - - if len(pred.shape) == len(gt.shape) + 1: - # In thise case gt has not been converted to multilabel - assert ( - regions is not None - ), "Regions must be supplied if ground truth is not already multilabel (i.e. multiple channels)" - translated_regions = translate_region_labels(regions=regions, labels=labels) - gt = convert_labels_to_regions(gt[np.newaxis], translated_regions) - for i in range(len(regions.keys())): - pred[i] *= 1 + i - gt[i] *= 1 + i - - if as_binary: - cmat = confusion_matrix( - np.around(gt.flatten()).astype(bool).astype(np.uint8), - np.around(pred.flatten()).astype(bool).astype(np.uint8), - labels=labelarr, + for case in tqdm(subjects, desc="Evaluating"): + if multilabel: + case_dict = evaluate_multilabel_case_segm( + case, regions=regions, labels_from_regions=labels_from_regions, **evaluation_args ) else: - cmat = confusion_matrix( - np.around(gt.flatten()).astype(np.uint8), - np.around(pred.flatten()).astype(np.uint8), - labels=labelarr, - ) + case_dict = evaluate_case_segm(case, **evaluation_args) - for label, region_name in enumerate(regions.keys()): - label += 1 - labeldict = {} - - tp = cmat[label, label] - fp = sum(cmat[:, label]) - tp - fn = sum(cmat[label, :]) - tp - tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric - for k, v in metrics.items(): - labeldict[k] = round(v(tp, fp, tn, fn), 4) - meandict[str(region_name)][k].append(labeldict[k]) - - if obj_metrics: - raise NotImplementedError - # now for the object metrics - # obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary) - # for k, v in obj_labeldict.items(): - # labeldict[k] = round(v, 4) - # meandict[str(label)][k].append(labeldict[k]) - - if surface_metrics: - if label == 0: - surface_labeldict = get_surface_metrics_for_label( - gt[label], pred[label], 0, spacing=spacing, tol=surface_tol, as_binary=as_binary - ) - else: - surface_labeldict = get_surface_metrics_for_label( - gt[label - 1], pred[label - 1], label, spacing=spacing, tol=surface_tol, as_binary=as_binary - ) - for k, v in surface_labeldict.items(): - labeldict[k] = round(v, 4) - meandict[str(region_name)][k].append(labeldict[k]) - casedict[str(region_name)] = labeldict - casedict["Prediction:"] = predpath - casedict["Ground Truth:"] = gtpath - - resultdict[case] = casedict - del pred, gt, cmat - - for label in regions.keys(): - meandict[str(label)] = { - k: round(np.nanmean(v), 4) if not np.all(np.isnan(v)) else 0 for k, v in meandict[str(label)].items() + result_dict[case] = case_dict + for label in case_dict.keys(): + for metric, val in case_dict[label].items(): + mean_dict[str(label)][metric].append(val) + + for label in labels: + mean_dict[str(label)] = { + k: round(np.nanmean(v), 4) if not np.all(np.isnan(v)) else 0 for k, v in mean_dict[str(label)].items() } - resultdict["mean"] = meandict - return resultdict + result_dict["mean"] = mean_dict + return result_dict -def evaluate_folder_segm( - labels, + +def evaluate_multilabel_case_segm( + case, metrics, - subjects, folder_with_predictions, folder_with_ground_truth, - as_binary: Optional[bool] = False, - obj_metrics: Optional[bool] = False, - surface_metrics: Optional[bool] = False, - surface_tol=1, + labels, + labels_from_regions, + as_binary, + obj_metrics, + surface_metrics, + surface_tol, + regions, ): - logging.info(f"segmentation evaluation with labels: {labels}") + case_dict = {} + predpath = join(folder_with_predictions, case) + gtpath = join(folder_with_ground_truth, case) + case_dict["Prediction:"] = predpath + case_dict["Ground Truth:"] = gtpath + + pred = nib.load(predpath) + spacing = get_nib_spacing(pred)[:3] + pred = pred.get_fdata().astype(np.uint8) + pred = pred.transpose([3, 0, 1, 2]) + gt = nib.load(gtpath).get_fdata() + + if len(pred.shape) == len(gt.shape) + 1: + # In thise case gt has not been converted to multilabel + assert ( + regions is not None + ), "Regions must be supplied if ground truth is not already multilabel (i.e. multiple channels)" + translated_regions = translate_region_labels(regions=regions, labels=labels) + gt = convert_labels_to_regions(gt[np.newaxis], translated_regions) + for i in range(len(regions.keys())): + pred[i] *= 1 + i + gt[i] *= 1 + i + + if as_binary: + cmat = confusion_matrix( + np.around(gt.flatten()).astype(bool).astype(np.uint8), + np.around(pred.flatten()).astype(bool).astype(np.uint8), + labels=labels_from_regions, + ) + else: + cmat = confusion_matrix( + np.around(gt.flatten()).astype(np.uint8), + np.around(pred.flatten()).astype(np.uint8), + labels=labels_from_regions, + ) + + for label, region_name in enumerate(regions.keys()): + label += 1 + labeldict = {} - sys.stdout.flush() - resultdict = {} - meandict = {} + tp = cmat[label, label] + fp = sum(cmat[:, label]) - tp + fn = sum(cmat[label, :]) - tp + tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric + for k, v in metrics.items(): + labeldict[k] = round(v(tp, fp, tn, fn), 4) - for label in labels: - meandict[str(label)] = {k: [] for k in list(metrics.keys()) + obj_metrics + surface_metrics} + if obj_metrics: + raise NotImplementedError + # now for the object metrics + # obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary) + # for k, v in obj_labeldict.items(): + # labeldict[k] = round(v, 4) - for case in tqdm(subjects, desc="Evaluating"): - casedict = {} - predpath = join(folder_with_predictions, case) - gtpath = join(folder_with_ground_truth, case) + if surface_metrics: + if label == 0: + surface_labeldict = get_surface_metrics_for_label( + gt[label], pred[label], 0, spacing=spacing, tol=surface_tol, as_binary=as_binary + ) + else: + surface_labeldict = get_surface_metrics_for_label( + gt[label - 1], pred[label - 1], label, spacing=spacing, tol=surface_tol, as_binary=as_binary + ) + for k, v in surface_labeldict.items(): + labeldict[k] = round(v, 4) - pred = nib.load(predpath) - spacing = get_nib_spacing(pred) - pred = pred.get_fdata() - gt = nib.load(gtpath).get_fdata() + case_dict[str(region_name)] = labeldict - if as_binary: - cmat = confusion_matrix( - np.around(gt.flatten()).astype(bool).astype(np.uint8), - np.around(pred.flatten()).astype(bool).astype(np.uint8), - labels=labels, - ) - else: - cmat = confusion_matrix( - np.around(gt.flatten()).astype(np.uint8), - np.around(pred.flatten()).astype(np.uint8), - labels=labels, - ) + # Case dict contains for labels "0", ... and metrics "dice", "f1", ... + # { + # "0": { "dice": 0.1, "f1": 0.2, ... }, + # ... + # } + return case_dict - for label in labels: - labeldict = {} - - tp = cmat[label, label] - fp = sum(cmat[:, label]) - tp - fn = sum(cmat[label, :]) - tp - tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric - for k, v in metrics.items(): - labeldict[k] = round(v(tp, fp, tn, fn), 4) - meandict[str(label)][k].append(labeldict[k]) - - if obj_metrics: - # now for the object metrics - obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary) - for k, v in obj_labeldict.items(): - labeldict[k] = round(v, 4) - meandict[str(label)][k].append(labeldict[k]) - - if surface_metrics: - surface_labeldict = get_surface_metrics_for_label( - gt, pred, label, spacing=spacing, tol=surface_tol, as_binary=as_binary - ) - for k, v in surface_labeldict.items(): - labeldict[k] = round(v, 4) - meandict[str(label)][k].append(labeldict[k]) - casedict[str(label)] = labeldict - casedict["Prediction:"] = predpath - casedict["Ground Truth:"] = gtpath - resultdict[case] = casedict - del pred, gt, cmat +def evaluate_case_segm( + case, + metrics, + folder_with_predictions, + folder_with_ground_truth, + labels, + as_binary, + obj_metrics, + surface_metrics, + surface_tol, +): + case_dict = {} + predpath = join(folder_with_predictions, case) + gtpath = join(folder_with_ground_truth, case) + + case_dict["Prediction:"] = predpath + case_dict["Ground Truth:"] = gtpath + + pred = nib.load(predpath) + spacing = get_nib_spacing(pred) + pred = pred.get_fdata() + gt = nib.load(gtpath).get_fdata() + + if as_binary: + cmat = confusion_matrix( + np.around(gt.flatten()).astype(bool).astype(np.uint8), + np.around(pred.flatten()).astype(bool).astype(np.uint8), + labels=labels, + ) + else: + cmat = confusion_matrix( + np.around(gt.flatten()).astype(np.uint8), + np.around(pred.flatten()).astype(np.uint8), + labels=labels, + ) for label in labels: - meandict[str(label)] = { - k: round(np.nanmean(v), 4) if not np.all(np.isnan(v)) else 0 for k, v in meandict[str(label)].items() - } - resultdict["mean"] = meandict - return resultdict + label_dict = {} + + tp = cmat[label, label] + fp = sum(cmat[:, label]) - tp + fn = sum(cmat[label, :]) - tp + tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric + for k, v in metrics.items(): + label_dict[k] = round(v(tp, fp, tn, fn), 4) + + if obj_metrics: + # now for the object metrics + obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary) + for k, v in obj_labeldict.items(): + label_dict[k] = round(v, 4) + + if surface_metrics: + surface_labeldict = get_surface_metrics_for_label( + gt, pred, label, spacing=spacing, tol=surface_tol, as_binary=as_binary + ) + for k, v in surface_labeldict.items(): + label_dict[k] = round(v, 4) + + case_dict[str(label)] = label_dict + + # Case dict contains for labels "0", ... and metrics "dice", "f1", ... + # { + # "0": { "dice": 0.1, "f1": 0.2, ... }, + # ... + # } + return case_dict def evaluate_folder_cls( diff --git a/yucca/pipeline/evaluation/YuccaEvaluator.py b/yucca/pipeline/evaluation/YuccaEvaluator.py index c2004184..d66b936d 100644 --- a/yucca/pipeline/evaluation/YuccaEvaluator.py +++ b/yucca/pipeline/evaluation/YuccaEvaluator.py @@ -207,31 +207,20 @@ def evaluate_folder(self): folder_with_ground_truth=self.folder_with_ground_truth, ) elif self.task_type == "segmentation": - if self.regions is not None: - return evaluate_multilabel_folder_segm( - labels=self.labels, - metrics=self.metrics, - subjects=self.pred_subjects, - folder_with_predictions=self.folder_with_predictions, - folder_with_ground_truth=self.folder_with_ground_truth, - as_binary=self.as_binary, - obj_metrics=self.obj_metrics, - regions=self.regions, - surface_metrics=self.surface_metrics, - surface_tol=self.surface_tol, - ) - else: - return evaluate_folder_segm( - labels=self.labelarr, - metrics=self.metrics, - subjects=self.pred_subjects, - folder_with_predictions=self.folder_with_predictions, - folder_with_ground_truth=self.folder_with_ground_truth, - as_binary=self.as_binary, - obj_metrics=self.obj_metrics, - surface_metrics=self.surface_metrics, - surface_tol=self.surface_tol, - ) + multilabel = self.regions is not None + return evaluate_folder_segm( + labels=self.labels if multilabel else self.labelarr, + metrics=self.metrics, + subjects=self.pred_subjects, + folder_with_predictions=self.folder_with_predictions, + folder_with_ground_truth=self.folder_with_ground_truth, + as_binary=self.as_binary, + obj_metrics=self.obj_metrics, + surface_metrics=self.surface_metrics, + surface_tol=self.surface_tol, + regions=self.regions, + multilabel=multilabel, + ) else: raise NotImplementedError("Invalid task type")