Skip to content

Commit

Permalink
Changes from PR and fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
asbjrnmunk committed Oct 4, 2024
1 parent 6d6445d commit 614bab9
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions yucca/functional/evaluation/evaluate_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@


def evaluate_folder_segm(
folder_with_predictions: str,
folder_with_ground_truth: str,
labels: dict,
metrics: dict,
subjects: list,
folder_with_predictions: str,
folder_with_ground_truth: str,
as_binary: Optional[bool] = False,
multilabel: bool = False,
obj_metrics: Optional[bool] = False,
regions: Optional[dict] = None,
surface_metrics: Optional[bool] = False,
surface_tol: int = 1,
regions: Optional[dict] = None,
multilabel: bool = False,
):
sys.stdout.flush()
result_dict = {}
Expand Down Expand Up @@ -64,7 +64,10 @@ def evaluate_folder_segm(
case_dict = evaluate_case_segm(case, **evaluation_args)

result_dict[case] = case_dict
for label in labels:

metric_labels = [key for key in case_dict.keys() if key not in ["prediction_path", "ground_truth_path"]]

for label in metric_labels:
for metric, val in case_dict[str(label)].items():
mean_dict[str(label)][metric].append(val)

Expand All @@ -79,24 +82,24 @@ def evaluate_folder_segm(

def evaluate_multilabel_case_segm(
case: str,
metrics: dict,
folder_with_predictions: str,
folder_with_ground_truth: str,
labels: dict,
labels_from_regions: np.array,
as_binary: bool,
obj_metrics: bool,
surface_metrics: bool,
surface_tol: int,
regions: dict,
metrics: dict,
as_binary: Optional[bool] = False,
obj_metrics: Optional[bool] = False,
regions: Optional[dict] = None,
surface_metrics: Optional[bool] = False,
surface_tol: int = 1,
):
assert regions is not None

case_dict = {}
predpath = join(folder_with_predictions, case)
gtpath = join(folder_with_ground_truth, case)
case_dict["Prediction:"] = predpath
case_dict["Ground Truth:"] = gtpath
case_dict["prediction_path"] = predpath
case_dict["ground_truth_path"] = gtpath

pred = nib.load(predpath)
spacing = get_nib_spacing(pred)[:3]
Expand Down Expand Up @@ -170,21 +173,21 @@ def evaluate_multilabel_case_segm(

def evaluate_case_segm(
case: str,
metrics: dict,
folder_with_predictions: str,
folder_with_ground_truth: str,
labels: dict,
as_binary: bool,
obj_metrics: bool,
surface_metrics: bool,
surface_tol: int,
metrics: dict,
as_binary: Optional[bool] = False,
obj_metrics: Optional[bool] = False,
surface_metrics: Optional[bool] = False,
surface_tol: int = 1,
):
case_dict = {}
predpath = join(folder_with_predictions, case)
gtpath = join(folder_with_ground_truth, case)

case_dict["Prediction:"] = predpath
case_dict["Ground Truth:"] = gtpath
case_dict["prediction_path"] = predpath
case_dict["ground_truth_path"] = gtpath

pred = nib.load(predpath)
spacing = get_nib_spacing(pred)
Expand Down

0 comments on commit 614bab9

Please sign in to comment.