Commit

Do refactor
asbjrnmunk committed Oct 4, 2024
1 parent 793ef2c commit 89a58c2
Showing 2 changed files with 201 additions and 181 deletions.
343 changes: 187 additions & 156 deletions yucca/functional/evaluation/evaluate_folder.py
@@ -13,195 +13,226 @@
from yucca.functional.evaluation.metrics import auroc


def evaluate_multilabel_folder_segm(
labels: dict,
def evaluate_folder_segm(
labels,
metrics,
subjects,
folder_with_predictions,
folder_with_ground_truth,
as_binary: Optional[bool] = False,
obj_metrics: Optional[bool] = False,
regions: Optional[list] = None,
surface_metrics: Optional[bool] = False,
surface_tol=1,
regions=None,
multilabel=False,
):
# predictions are multilabel at this point (c, h, w, d)
# ground truth MAY be converted, but may also be (h, w, d) in which case we use the label_regions
# to convert for multilabel evaluation
logging.info(f"Multilabel segmentation evaluation with regions: {regions} and labels: {labels}")

sys.stdout.flush()
resultdict = {}
meandict = {}
result_dict = {}
mean_dict = {}

labelarr = np.array(range(len(regions.keys()) + 1), dtype=np.uint8)
if multilabel:
assert regions is not None
# predictions are multilabel at this point (c, h, w, d)
# ground truth MAY be converted, but may also be (h, w, d) in which case we use the label_regions
# to convert for multilabel evaluation

for label in regions.keys():
meandict[str(label)] = {k: [] for k in list(metrics.keys()) + obj_metrics + surface_metrics}
labels_from_regions = np.array(range(len(regions.keys()) + 1), dtype=np.uint8)
logging.info(f"Multilabel segmentation evaluation with regions: {regions} and labels: {labels}")
else:
logging.info(f"segmentation evaluation with labels: {labels}")

for case in tqdm(subjects, desc="Evaluating"):
casedict = {}
predpath = join(folder_with_predictions, case)
gtpath = join(folder_with_ground_truth, case)
for label in labels:
mean_dict[str(label)] = {k: [] for k in list(metrics.keys()) + obj_metrics + surface_metrics}

evaluation_args = {
"folder_with_ground_truth": folder_with_ground_truth,
"folder_with_predictions": folder_with_predictions,
"labels": labels,
"as_binary": as_binary,
"obj_metrics": obj_metrics,
"surface_metrics": surface_metrics,
"surface_tol": surface_tol,
"metrics": metrics,
}

pred = nib.load(predpath)
spacing = get_nib_spacing(pred)[:3]
pred = pred.get_fdata().astype(np.uint8)
pred = pred.transpose([3, 0, 1, 2])
gt = nib.load(gtpath).get_fdata()

if len(pred.shape) == len(gt.shape) + 1:
            # In this case gt has not been converted to multilabel
assert (
regions is not None
), "Regions must be supplied if ground truth is not already multilabel (i.e. multiple channels)"
translated_regions = translate_region_labels(regions=regions, labels=labels)
gt = convert_labels_to_regions(gt[np.newaxis], translated_regions)
for i in range(len(regions.keys())):
pred[i] *= 1 + i
gt[i] *= 1 + i

if as_binary:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(bool).astype(np.uint8),
np.around(pred.flatten()).astype(bool).astype(np.uint8),
labels=labelarr,
for case in tqdm(subjects, desc="Evaluating"):
if multilabel:
case_dict = evaluate_multilabel_case_segm(
case, regions=regions, labels_from_regions=labels_from_regions, **evaluation_args
)
else:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(np.uint8),
np.around(pred.flatten()).astype(np.uint8),
labels=labelarr,
)
case_dict = evaluate_case_segm(case, **evaluation_args)

for label, region_name in enumerate(regions.keys()):
label += 1
labeldict = {}

tp = cmat[label, label]
fp = sum(cmat[:, label]) - tp
fn = sum(cmat[label, :]) - tp
tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric
for k, v in metrics.items():
labeldict[k] = round(v(tp, fp, tn, fn), 4)
meandict[str(region_name)][k].append(labeldict[k])

if obj_metrics:
raise NotImplementedError
# now for the object metrics
# obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary)
# for k, v in obj_labeldict.items():
# labeldict[k] = round(v, 4)
# meandict[str(label)][k].append(labeldict[k])

if surface_metrics:
if label == 0:
surface_labeldict = get_surface_metrics_for_label(
gt[label], pred[label], 0, spacing=spacing, tol=surface_tol, as_binary=as_binary
)
else:
surface_labeldict = get_surface_metrics_for_label(
gt[label - 1], pred[label - 1], label, spacing=spacing, tol=surface_tol, as_binary=as_binary
)
for k, v in surface_labeldict.items():
labeldict[k] = round(v, 4)
meandict[str(region_name)][k].append(labeldict[k])
casedict[str(region_name)] = labeldict
casedict["Prediction:"] = predpath
casedict["Ground Truth:"] = gtpath

resultdict[case] = casedict
del pred, gt, cmat

for label in regions.keys():
meandict[str(label)] = {
k: round(np.nanmean(v), 4) if not np.all(np.isnan(v)) else 0 for k, v in meandict[str(label)].items()
result_dict[case] = case_dict
for label in case_dict.keys():
for metric, val in case_dict[label].items():
mean_dict[str(label)][metric].append(val)

for label in labels:
mean_dict[str(label)] = {
k: round(np.nanmean(v), 4) if not np.all(np.isnan(v)) else 0 for k, v in mean_dict[str(label)].items()
}
resultdict["mean"] = meandict
return resultdict
result_dict["mean"] = mean_dict

return result_dict

def evaluate_folder_segm(
labels,

def evaluate_multilabel_case_segm(
case,
metrics,
subjects,
folder_with_predictions,
folder_with_ground_truth,
as_binary: Optional[bool] = False,
obj_metrics: Optional[bool] = False,
surface_metrics: Optional[bool] = False,
surface_tol=1,
labels,
labels_from_regions,
as_binary,
obj_metrics,
surface_metrics,
surface_tol,
regions,
):
logging.info(f"segmentation evaluation with labels: {labels}")
case_dict = {}
predpath = join(folder_with_predictions, case)
gtpath = join(folder_with_ground_truth, case)
case_dict["Prediction:"] = predpath
case_dict["Ground Truth:"] = gtpath

pred = nib.load(predpath)
spacing = get_nib_spacing(pred)[:3]
pred = pred.get_fdata().astype(np.uint8)
pred = pred.transpose([3, 0, 1, 2])
gt = nib.load(gtpath).get_fdata()

if len(pred.shape) == len(gt.shape) + 1:
        # In this case gt has not been converted to multilabel
assert (
regions is not None
), "Regions must be supplied if ground truth is not already multilabel (i.e. multiple channels)"
translated_regions = translate_region_labels(regions=regions, labels=labels)
gt = convert_labels_to_regions(gt[np.newaxis], translated_regions)
for i in range(len(regions.keys())):
pred[i] *= 1 + i
gt[i] *= 1 + i

if as_binary:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(bool).astype(np.uint8),
np.around(pred.flatten()).astype(bool).astype(np.uint8),
labels=labels_from_regions,
)
else:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(np.uint8),
np.around(pred.flatten()).astype(np.uint8),
labels=labels_from_regions,
)

for label, region_name in enumerate(regions.keys()):
label += 1
labeldict = {}

sys.stdout.flush()
resultdict = {}
meandict = {}
tp = cmat[label, label]
fp = sum(cmat[:, label]) - tp
fn = sum(cmat[label, :]) - tp
tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric
for k, v in metrics.items():
labeldict[k] = round(v(tp, fp, tn, fn), 4)

for label in labels:
meandict[str(label)] = {k: [] for k in list(metrics.keys()) + obj_metrics + surface_metrics}
if obj_metrics:
raise NotImplementedError
# now for the object metrics
# obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary)
# for k, v in obj_labeldict.items():
# labeldict[k] = round(v, 4)

for case in tqdm(subjects, desc="Evaluating"):
casedict = {}
predpath = join(folder_with_predictions, case)
gtpath = join(folder_with_ground_truth, case)
if surface_metrics:
if label == 0:
surface_labeldict = get_surface_metrics_for_label(
gt[label], pred[label], 0, spacing=spacing, tol=surface_tol, as_binary=as_binary
)
else:
surface_labeldict = get_surface_metrics_for_label(
gt[label - 1], pred[label - 1], label, spacing=spacing, tol=surface_tol, as_binary=as_binary
)
for k, v in surface_labeldict.items():
labeldict[k] = round(v, 4)

pred = nib.load(predpath)
spacing = get_nib_spacing(pred)
pred = pred.get_fdata()
gt = nib.load(gtpath).get_fdata()
case_dict[str(region_name)] = labeldict

if as_binary:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(bool).astype(np.uint8),
np.around(pred.flatten()).astype(bool).astype(np.uint8),
labels=labels,
)
else:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(np.uint8),
np.around(pred.flatten()).astype(np.uint8),
labels=labels,
)
    # Case dict contains, for each label "0", ..., the metrics "dice", "f1", ...
# {
# "0": { "dice": 0.1, "f1": 0.2, ... },
# ...
# }
return case_dict

for label in labels:
labeldict = {}

tp = cmat[label, label]
fp = sum(cmat[:, label]) - tp
fn = sum(cmat[label, :]) - tp
tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric
for k, v in metrics.items():
labeldict[k] = round(v(tp, fp, tn, fn), 4)
meandict[str(label)][k].append(labeldict[k])

if obj_metrics:
# now for the object metrics
obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary)
for k, v in obj_labeldict.items():
labeldict[k] = round(v, 4)
meandict[str(label)][k].append(labeldict[k])

if surface_metrics:
surface_labeldict = get_surface_metrics_for_label(
gt, pred, label, spacing=spacing, tol=surface_tol, as_binary=as_binary
)
for k, v in surface_labeldict.items():
labeldict[k] = round(v, 4)
meandict[str(label)][k].append(labeldict[k])
casedict[str(label)] = labeldict
casedict["Prediction:"] = predpath
casedict["Ground Truth:"] = gtpath

resultdict[case] = casedict
del pred, gt, cmat
def evaluate_case_segm(
case,
metrics,
folder_with_predictions,
folder_with_ground_truth,
labels,
as_binary,
obj_metrics,
surface_metrics,
surface_tol,
):
case_dict = {}
predpath = join(folder_with_predictions, case)
gtpath = join(folder_with_ground_truth, case)

case_dict["Prediction:"] = predpath
case_dict["Ground Truth:"] = gtpath

pred = nib.load(predpath)
spacing = get_nib_spacing(pred)
pred = pred.get_fdata()
gt = nib.load(gtpath).get_fdata()

if as_binary:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(bool).astype(np.uint8),
np.around(pred.flatten()).astype(bool).astype(np.uint8),
labels=labels,
)
else:
cmat = confusion_matrix(
np.around(gt.flatten()).astype(np.uint8),
np.around(pred.flatten()).astype(np.uint8),
labels=labels,
)

for label in labels:
meandict[str(label)] = {
k: round(np.nanmean(v), 4) if not np.all(np.isnan(v)) else 0 for k, v in meandict[str(label)].items()
}
resultdict["mean"] = meandict
return resultdict
label_dict = {}

tp = cmat[label, label]
fp = sum(cmat[:, label]) - tp
fn = sum(cmat[label, :]) - tp
tn = np.sum(cmat) - tp - fp - fn # often a redundant and meaningless metric
for k, v in metrics.items():
label_dict[k] = round(v(tp, fp, tn, fn), 4)

if obj_metrics:
# now for the object metrics
obj_labeldict = get_obj_stats_for_label(gt, pred, label, spacing=spacing, as_binary=as_binary)
for k, v in obj_labeldict.items():
label_dict[k] = round(v, 4)

if surface_metrics:
surface_labeldict = get_surface_metrics_for_label(
gt, pred, label, spacing=spacing, tol=surface_tol, as_binary=as_binary
)
for k, v in surface_labeldict.items():
label_dict[k] = round(v, 4)

case_dict[str(label)] = label_dict

    # Case dict contains, for each label "0", ..., the metrics "dice", "f1", ...
# {
# "0": { "dice": 0.1, "f1": 0.2, ... },
# ...
# }
return case_dict


def evaluate_folder_cls(

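For orientation, the per-label metric computation that both evaluate_case_segm and evaluate_multilabel_case_segm perform above boils down to reading counts off a confusion matrix. The standalone sketch below illustrates that pattern; the toy arrays and the dice helper are assumptions for illustration and are not part of this commit or of yucca.functional.evaluation.metrics.

# Standalone sketch of the per-label confusion-matrix pattern used above.
# The toy arrays and the dice() helper are assumptions for illustration only.
import numpy as np
from sklearn.metrics import confusion_matrix


def dice(tp, fp, tn, fn):
    # Dice coefficient from confusion-matrix counts; NaN when the label is absent.
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else float("nan")


labels = [0, 1, 2]
gt = np.array([0, 1, 1, 2, 2, 0])    # toy ground-truth voxels (flattened)
pred = np.array([0, 1, 2, 2, 2, 0])  # toy predicted voxels (flattened)

# Rows are ground truth, columns are predictions.
cmat = confusion_matrix(gt, pred, labels=labels)

for label in labels:
    tp = cmat[label, label]
    fp = cmat[:, label].sum() - tp
    fn = cmat[label, :].sum() - tp
    tn = cmat.sum() - tp - fp - fn  # rarely informative for segmentation
    print(label, round(dice(tp, fp, tn, fn), 4))

The folder-level function in the diff then averages these per-case, per-label values into result_dict["mean"] using np.nanmean.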