Speed up metrics computation by parallelizing across subjects #11

Merged 6 commits on Jun 11, 2024
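In short, the sequential per-subject loop in `main()` is replaced with a process pool: `functools.partial` freezes the shared `metrics` argument, `pool.starmap` feeds each prediction/reference pair to a `process_subject` wrapper, and the pool size is exposed through a new `-jobs` flag. A minimal, self-contained sketch of that pattern follows; the stand-in metric function and file names are illustrative only, not the repository's actual implementation.

```python
# Minimal sketch of the parallelization pattern adopted in this PR:
# multiprocessing.Pool + functools.partial + starmap across subject pairs.
# compute_metrics_single_subject below is a stand-in, not the real function.
from functools import partial
from multiprocessing import Pool, cpu_count


def compute_metrics_single_subject(prediction, reference, metrics):
    """Stand-in for the real metric computation on one prediction/reference pair."""
    return {'prediction': prediction, 'reference': reference, 'metrics': list(metrics)}


def process_subject(prediction_file, reference_file, metrics):
    """Wrapper so starmap can unpack (prediction, reference) while metrics stays fixed."""
    return compute_metrics_single_subject(prediction_file, reference_file, metrics)


if __name__ == '__main__':
    prediction_files = ['sub-01_pred.nii.gz', 'sub-02_pred.nii.gz']
    reference_files = ['sub-01_gt.nii.gz', 'sub-02_gt.nii.gz']
    metrics = ['dsc', 'nsd']

    # Freeze the shared metrics argument; starmap then only supplies the file pair.
    func = partial(process_subject, metrics=metrics)
    # cpu_count() // 8 (the PR's default for -jobs) can be 0 on small machines,
    # so this sketch clamps the pool size to at least one worker.
    with Pool(max(cpu_count() // 8, 1)) as pool:
        results = pool.starmap(func, zip(prediction_files, reference_files))
    print(results)
```

The same three pieces (`process_subject`, `partial`, `starmap`) appear in the diff to `main()` below.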
46 changes: 29 additions & 17 deletions compute_metrics_reloaded.py
@@ -32,7 +32,7 @@
The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based).
The metrics are computed for each unique label (class) in the reference (ground truth) image.

Authors: Jan Valosek
Authors: Jan Valosek, Naga Karthik
"""


@@ -41,6 +41,8 @@
import numpy as np
import nibabel as nib
import pandas as pd
from multiprocessing import Pool, cpu_count
from functools import partial

from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM

@@ -81,6 +83,8 @@ def get_parser():
'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.')
parser.add_argument('-output', type=str, default='metrics.csv', required=False,
help='Path to the output CSV file to save the metrics. Default: metrics.csv')
parser.add_argument('-jobs', type=int, default=cpu_count()//8, required=False,
help='Number of CPU cores to use in parallel. Default: cpu_count()//8.')

return parser

@@ -130,9 +134,7 @@ def compute_metrics_single_subject(prediction, reference, metrics):
:param metrics: list of metrics to compute
"""
# load nifti images
print(f'Processing...')
print(f'\tPrediction: {os.path.basename(prediction)}')
print(f'\tReference: {os.path.basename(reference)}')
print(f'\nProcessing:\n\tPrediction: {os.path.basename(prediction)}\n\tReference: {os.path.basename(reference)}')
prediction_data = load_nifti_image(prediction)
reference_data = load_nifti_image(reference)

@@ -159,7 +161,6 @@ def compute_metrics_single_subject(prediction, reference, metrics):
# by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions
for label in unique_labels:
# create binary masks for the current label
print(f'\tLabel {label}')
prediction_data_label = np.array(prediction_data == label, dtype=float)
reference_data_label = np.array(reference_data == label, dtype=float)

@@ -171,12 +172,9 @@
# add the metrics to the output dictionary
metrics_dict[label] = dict_seg

if label == max(unique_labels):
break # break to loop to avoid processing the background label ("else" block)
# Special case when both the reference and prediction images are empty
else:
label = 1
print(f'\tLabel {label} -- both the reference and prediction are empty')
bpm = BPM(prediction_data, reference_data, measures=metrics)
dict_seg = bpm.to_dict_meas()

@@ -216,8 +214,14 @@ def build_output_dataframe(output_list):
return df


def main():
def process_subject(prediction_file, reference_file, metrics):
"""
Wrapper function to process a single subject.
"""
return compute_metrics_single_subject(prediction_file, reference_file, metrics)


def main():
# parse command line arguments
parser = get_parser()
args = parser.parse_args()
@@ -227,19 +231,22 @@ def main():

# Print the metrics to be computed
print(f'Computing metrics: {args.metrics}')
print(f'Using {args.jobs} CPU cores in parallel ...')

# Args.prediction and args.reference are paths to folders with multiple nii.gz files (i.e., MULTIPLE subjects)
if os.path.isdir(args.prediction) and os.path.isdir(args.reference):
# Get all files in the directories
prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference)
# Loop over the subjects
for i in range(len(prediction_files)):
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(prediction_files[i], reference_files[i], args.metrics)
# Append the output dictionary (representing a single reference-prediction pair per subject) to the
# output_list
output_list.append(metrics_dict)
# Args.prediction and args.reference are paths to nii.gz files from a SINGLE subject

# Use multiprocessing to parallelize the computation
with Pool(args.jobs) as pool:
# Create a partial function to pass the metrics argument to the process_subject function
func = partial(process_subject, metrics=args.metrics)
# Compute metrics for each subject in parallel
results = pool.starmap(func, zip(prediction_files, reference_files))

# Collect the results
output_list.extend(results)
else:
metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
# Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
@@ -252,6 +259,11 @@ def main():
df_mean = (df.drop(columns=['reference', 'prediction', 'EmptyRef', 'EmptyPred']).groupby('label').
agg(['mean', 'std']).reset_index())

# Convert multi-index to flat index
df_mean.columns = ['_'.join(col).strip() for col in df_mean.columns.values]
# Rename column `label_` back to `label`
df_mean.rename(columns={'label_': 'label'}, inplace=True)

# Rename columns
df.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True)
df_mean.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True)
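The other addition in the last hunk above flattens the two-level columns that `groupby('label').agg(['mean', 'std'])` produces. A toy pandas sketch of that flattening, on hypothetical data rather than the script's real output:

```python
# Toy illustration (hypothetical data) of the MultiIndex flattening added above:
# agg(['mean', 'std']) yields two-level columns such as ('dsc', 'mean'),
# which '_'.join turns into flat names like 'dsc_mean'.
import pandas as pd

df = pd.DataFrame({'label': [1.0, 1.0, 2.0], 'dsc': [0.80, 0.90, 0.70]})
df_mean = df.groupby('label').agg(['mean', 'std']).reset_index()

df_mean.columns = ['_'.join(col).strip() for col in df_mean.columns.values]
# ('label', '') becomes 'label_' after the join, hence the rename back to 'label'.
df_mean.rename(columns={'label_': 'label'}, inplace=True)
print(df_mean)  # columns: label, dsc_mean, dsc_std
```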
10 changes: 5 additions & 5 deletions test/test_metrics/test_pairwise_measures_neuropoly.py
@@ -287,11 +287,11 @@ def test_non_empty_ref_and_pred_multi_class(self):
Multi-class (i.e., voxels with values 1 and 2, e.g., region-based nnUNet training)
"""

expected_metrics = {1.0: {'dsc': 0.25,
'fbeta': 0.2500000055879354,
'nsd': 0.5,
'vol_diff': 2.0,
'rel_vol_error': 200.0,
expected_metrics = {1.0: {'dsc': 0.6521739130434783,
'fbeta': 0.5769230751596257,
'nsd': 0.23232323232323232,
'vol_diff': 2.6,
'rel_vol_error': 260.0,
'EmptyRef': False,
'EmptyPred': False,
'lesion_ppv': 1.0,