Skip to content

Commit

Permalink
add overlap ratio to arg parser, and input arg to pairwise measures c…
Browse files Browse the repository at this point in the history
…lass
  • Loading branch information
naga-karthik committed Dec 9, 2024
1 parent e7bfea7 commit 53091b5
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions compute_metrics_reloaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def get_parser():
help='Path to the output CSV file to save the metrics. Default: metrics.csv')
parser.add_argument('-jobs', type=int, default=cpu_count()//8, required=False,
help='Number of CPU cores to use in parallel. Default: cpu_count()//8.')
parser.add_argument('--overlap-ratio', type=float, default=0.1, required=False,
help='Overlap ratio between the ground-truth and prediction to be considered as true positive (TP).'
'Used only in counting TPs in lesion-wise metrics. Default: 0.1')

return parser

Expand Down Expand Up @@ -126,7 +129,7 @@ def get_images_in_folder(prediction, reference):
return prediction_files, reference_files


def compute_metrics_single_subject(prediction, reference, metrics):
def compute_metrics_single_subject(prediction, reference, metrics, overlap_ratio):
"""
Compute MetricsReloaded metrics for a single subject
:param prediction: path to the nifti image with the prediction
Expand Down Expand Up @@ -169,7 +172,7 @@ def compute_metrics_single_subject(prediction, reference, metrics):
prediction_data_label = np.array(prediction_data == label, dtype=float)
reference_data_label = np.array(reference_data == label, dtype=float)

bpm = BPM(prediction_data_label, reference_data_label, measures=metrics)
bpm = BPM(prediction_data_label, reference_data_label, measures=metrics, overlap_ratio=overlap_ratio)
dict_seg = bpm.to_dict_meas()
# Store info whether the reference or prediction is empty
dict_seg['EmptyRef'] = bpm.flag_empty_ref
Expand All @@ -180,7 +183,7 @@ def compute_metrics_single_subject(prediction, reference, metrics):
# Special case when both the reference and prediction images are empty
else:
label = 1.0
bpm = BPM(prediction_data, reference_data, measures=metrics)
bpm = BPM(prediction_data, reference_data, measures=metrics, overlap_ratio=overlap_ratio)
dict_seg = bpm.to_dict_meas()

# Store info whether the reference or prediction is empty
Expand Down Expand Up @@ -219,11 +222,11 @@ def build_output_dataframe(output_list):
return df


def process_subject(prediction_file, reference_file, metrics):
def process_subject(prediction_file, reference_file, metrics, overlap_ratio):
"""
Wrapper function to process a single subject.
"""
return compute_metrics_single_subject(prediction_file, reference_file, metrics)
return compute_metrics_single_subject(prediction_file, reference_file, metrics, overlap_ratio)


def main():
Expand All @@ -246,14 +249,14 @@ def main():
# Use multiprocessing to parallelize the computation
with Pool(args.jobs) as pool:
# Create a partial function to pass the metrics argument to the process_subject function
func = partial(process_subject, metrics=args.metrics)
func = partial(process_subject, metrics=args.metrics, overlap_ratio=args.overlap_ratio)
# Compute metrics for each subject in parallel
results = pool.starmap(func, zip(prediction_files, reference_files))

# Collect the results
output_list.extend(results)
else:
metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics, args.overlap_ratio)
# Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
output_list.append(metrics_dict)

Expand Down

0 comments on commit 53091b5

Please sign in to comment.