Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match GTs and predictions based on participant_id, acq_id, and run_id #17

Merged
merged 11 commits into from
Dec 11, 2024
87 changes: 0 additions & 87 deletions .github/workflows/python-app.yml

This file was deleted.

77 changes: 65 additions & 12 deletions compute_metrics_reloaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
python compute_metrics_reloaded.py
-reference /path/to/reference
-prediction /path/to/prediction
NOTE: The prediction and reference files are matched based on the participant_id, session_id, acq_id, chunk_id, and run_id.

The metrics to be computed can be specified using the `-metrics` argument. For example, to compute only the Dice
similarity coefficient (DSC) and Normalized surface distance (NSD), use:
Expand All @@ -37,6 +38,7 @@


import os
import re
import argparse
import numpy as np
import nibabel as nib
Expand Down Expand Up @@ -103,25 +105,76 @@ def load_nifti_image(file_path):
return nifti_image.get_fdata()


def get_images_in_folder(prediction, reference):
def fetch_bids_compatible_keys(filename_path, prefix='sub-'):
    """
    Get participant_id, session_id, acq_id, chunk_id and run_id from the input BIDS-compatible filename or file path.
    The function works both on absolute file paths as well as filenames.

    For each entity, the first occurrence in the input string is returned. An entity is recognised only when its key
    starts the string or follows a non-alphanumeric character (so 'analyses-v2' is not mistaken for 'ses-v2'), and
    its value ends at the next underscore, slash, dot, or at the end of the string (so an entity placed right before
    the file extension, e.g. 'sub-001_run-01.nii.gz', is still detected).

    :param filename_path: input nifti filename (e.g., sub-001_ses-01_T1w.nii.gz) or file path
    (e.g., /home/user/bids/sub-001/ses-01/anat/sub-001_ses-01_T1w.nii.gz)
    :param prefix: literal prefix of the participant ID in the filename (default: 'sub-')
    :return: participant_id: participant ID (e.g., sub-001), or "" if not found
    :return: session_id: session ID (e.g., ses-01), or "" if not found
    :return: acq_id: acquisition ID (e.g., acq-01), or "" if not found
    :return: chunk_id: chunk ID (e.g., chunk-1), or "" if not found
    :return: run_id: run ID (e.g., run-01), or "" if not found
    """

    def _find_entity(key):
        # REGEX explanation
        # (?<![A-Za-z0-9]) - the key must not be glued to a preceding letter or digit
        # re.escape(key)   - the key is matched literally, even if it contains regex metacharacters
        # .*?              - match any character (except newline) as few times as possible
        # (?=[_/.]|$)      - stop right before an underscore, slash, dot, or the end of the string
        pattern = rf'(?<![A-Za-z0-9]){re.escape(key)}.*?(?=[_/.]|$)'
        match = re.search(pattern, filename_path)
        return match.group(0) if match else ""

    participant_id = _find_entity(prefix)
    session_id = _find_entity('ses-')
    acq_id = _find_entity('acq-')
    chunk_id = _find_entity('chunk-')
    run_id = _find_entity('run-')

    return participant_id, session_id, acq_id, chunk_id, run_id


def get_images(prediction, reference):
    """
    Get all files (predictions and references/ground truths) in the input directories.
    The prediction and reference files are matched based on the participant_id, session_id, acq_id, chunk_id,
    and run_id parsed from each file's name. Files without a counterpart in the other directory are left out.
    :param prediction: path to the directory with prediction files
    :param reference: path to the directory with reference (ground truth) files
    :return: list of prediction files, list of reference/ground truth files (same length; the i-th entries of both
    lists belong together)
    :raises FileNotFoundError: if either directory contains no '.nii.gz' file
    """
    # Get all files in the directories. os.listdir() returns entries in arbitrary order, so sort the listings to make
    # the output reproducible across runs and platforms.
    prediction_files = sorted(os.path.join(prediction, f) for f in os.listdir(prediction) if f.endswith('.nii.gz'))
    reference_files = sorted(os.path.join(reference, f) for f in os.listdir(reference) if f.endswith('.nii.gz'))

    if not prediction_files:
        raise FileNotFoundError(f'No prediction files found in {prediction}.')
    if not reference_files:
        raise FileNotFoundError(f'No reference (ground truths) files found in {reference}.')

    key_columns = ['participant_id', 'session_id', 'acq_id', 'chunk_id', 'run_id']

    def _keyed_frame(files):
        # Build a dataframe with one row per file: the five matching keys plus the full file path.
        # The keys are parsed from the basename only; otherwise entity-like text in the directory path
        # (e.g. '/out/run-1_preds/') would leak into the keys of every file and break the matching.
        keys = [fetch_bids_compatible_keys(os.path.basename(f)) for f in files]
        frame = pd.DataFrame(keys, columns=key_columns)
        frame['filename'] = files
        return frame

    # Inner join: keep only the rows where both prediction and reference files exist.
    # sort=True orders the pairs lexicographically by the matching keys.
    df = pd.merge(_keyed_frame(prediction_files), _keyed_frame(reference_files),
                  on=key_columns, how='inner', sort=True, suffixes=('_pred', '_ref'))

    prediction_files = df['filename_pred'].tolist()
    reference_files = df['filename_ref'].tolist()

    return prediction_files, reference_files

Expand Down Expand Up @@ -236,7 +289,7 @@ def main():
# Args.prediction and args.reference are paths to folders with multiple nii.gz files (i.e., MULTIPLE subjects)
if os.path.isdir(args.prediction) and os.path.isdir(args.reference):
# Get all files in the directories
prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference)
prediction_files, reference_files = get_images(args.prediction, args.reference)

# Use multiprocessing to parallelize the computation
with Pool(args.jobs) as pool:
Expand Down
127 changes: 126 additions & 1 deletion test/test_metrics/test_pairwise_measures_neuropoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import numpy as np
import nibabel as nib
from compute_metrics_reloaded import compute_metrics_single_subject
from compute_metrics_reloaded import compute_metrics_single_subject, get_images, fetch_bids_compatible_keys
import tempfile

METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error', 'lesion_ppv', 'lesion_sensitivity', 'lesion_f1_score',
Expand Down Expand Up @@ -358,6 +358,131 @@ def test_non_empty_ref_and_pred_with_full_overlap(self):
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)

class TestGetImages(unittest.TestCase):
    """
    Tests for get_images: pairing of prediction and reference files found in two directories.
    """

    def setUp(self):
        """
        Prepare two empty temporary directories (predictions and references) for each test.
        """
        self.pred_dir = tempfile.TemporaryDirectory()
        self.ref_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        """
        Remove the temporary directories together with everything created in them.
        """
        self.pred_dir.cleanup()
        self.ref_dir.cleanup()

    def create_temp_file(self, directory, filename):
        """
        Write a small placeholder file called `filename` into `directory` and return its path.
        """
        file_path = os.path.join(directory, filename)
        with open(file_path, 'w') as handle:
            handle.write('dummy content')
        return file_path

    def _populate_and_match(self, pred_names, ref_names):
        """
        Create the given prediction and reference files, then return the output of get_images.
        """
        for pred_name in pred_names:
            self.create_temp_file(self.pred_dir.name, pred_name)
        for ref_name in ref_names:
            self.create_temp_file(self.ref_dir.name, ref_name)
        return get_images(self.pred_dir.name, self.ref_dir.name)

    def test_matching_files(self):
        """
        A prediction and a reference carrying identical keys are paired.
        """
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_ses-01_acq-01_chunk-1_run-01_pred.nii.gz"],
            ["sub-01_ses-01_acq-01_chunk-1_run-01_ref.nii.gz"],
        )
        self.assertEqual(len(pred_files), 1)
        self.assertEqual(len(ref_files), 1)

    def test_mismatched_files(self):
        """
        Files whose keys differ produce no pairs at all.
        """
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_ses-01_acq-01_chunk-1_run-01_pred.nii.gz"],
            ["sub-02_ses-01_acq-02_chunk-1_run-02_ref.nii.gz"],
        )
        self.assertEqual(len(pred_files), 0)
        self.assertEqual(len(ref_files), 0)

    def test_ses_id_empty(self):
        """
        Matching still works when the filenames carry no session entity.
        """
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_acq-01_chunk-1_run-01_pred.nii.gz"],
            ["sub-01_acq-01_chunk-1_run-01_ref.nii.gz"],
        )
        self.assertEqual(len(pred_files), 1)
        self.assertEqual(len(ref_files), 1)
        self.assertIn("sub-01_acq-01_chunk-1_run-01_pred.nii.gz", pred_files[0])
        self.assertIn("sub-01_acq-01_chunk-1_run-01_ref.nii.gz", ref_files[0])

    def test_acq_id_empty(self):
        """
        Matching still works when the filenames carry no acquisition entity.
        """
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_ses-01_chunk-1_run-01_pred.nii.gz"],
            ["sub-01_ses-01_chunk-1_run-01_ref.nii.gz"],
        )
        self.assertEqual(len(pred_files), 1)
        self.assertEqual(len(ref_files), 1)
        self.assertIn("sub-01_ses-01_chunk-1_run-01_pred.nii.gz", pred_files[0])
        self.assertIn("sub-01_ses-01_chunk-1_run-01_ref.nii.gz", ref_files[0])

    def test_chunk_id_empty(self):
        """
        Matching still works when the filenames carry no chunk entity.
        """
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_ses-01_acq-01_run-01_pred.nii.gz"],
            ["sub-01_ses-01_acq-01_run-01_ref.nii.gz"],
        )

        # Exactly one pair, made of the two files created above
        self.assertEqual(len(pred_files), 1)
        self.assertEqual(len(ref_files), 1)
        self.assertIn("sub-01_ses-01_acq-01_run-01_pred.nii.gz", pred_files[0])
        self.assertIn("sub-01_ses-01_acq-01_run-01_ref.nii.gz", ref_files[0])

    def test_run_id_empty(self):
        """
        Matching still works when the filenames carry no run entity.
        """
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_ses-01_acq-01_chunk-1_pred.nii.gz"],
            ["sub-01_ses-01_acq-01_chunk-1_ref.nii.gz"],
        )

        # Exactly one pair, made of the two files created above
        self.assertEqual(len(pred_files), 1)
        self.assertEqual(len(ref_files), 1)
        self.assertIn("sub-01_ses-01_acq-01_chunk-1_pred.nii.gz", pred_files[0])
        self.assertIn("sub-01_ses-01_acq-01_chunk-1_ref.nii.gz", ref_files[0])

    def test_no_files(self):
        """
        Empty directories must make get_images raise FileNotFoundError
        with a message naming the prediction directory.
        """
        with self.assertRaises(FileNotFoundError) as context:
            get_images(self.pred_dir.name, self.ref_dir.name)
        # The message is inspected after the block, once the exception has been captured
        self.assertIn(f'No prediction files found in {self.pred_dir.name}', str(context.exception))

    def test_partial_matching(self):
        """
        Only the files present on both sides are returned; the rest are ignored.
        """
        # The second prediction has no matching reference (GT) file, so it must not appear in the output
        pred_files, ref_files = self._populate_and_match(
            ["sub-01_acq-01_run-01_pred.nii.gz", "sub-02_acq-02_run-02_pred.nii.gz"],
            ["sub-01_acq-01_run-01_ref.nii.gz"],
        )
        self.assertEqual(len(pred_files), 1)
        self.assertEqual(len(ref_files), 1)


# Run the unit tests in this file when it is executed directly as a script
if __name__ == '__main__':
    unittest.main()
Loading