Commit 0bbeaa2

Merge pull request #44 from ivadomed/rb/data_type_generalize
`convert_bids_to_nnUNetV2.py` - add support of data types other than `anat`
rohanbanerjee authored Feb 19, 2024
2 parents e9cbfd8 + 8937bb5 commit 0bbeaa2
Showing 4 changed files with 256 additions and 12 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,40 @@
# This is a basic workflow to help you get started with Actions

name: CI

# Controls when the action will run.
on:
  # Triggers the workflow on push or pull request events but only for the master branch
  push:
    branches: [ master ]
  pull_request:
    branches: '*'

  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
  # This workflow contains a single job called "build"
  build:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v4

      - name: Install Python 3
        uses: actions/setup-python@v4
        with:
          python-version: 3.8.18

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r dataset_conversion/requirements.txt
      - name: Run tests with unittest
        run: |
          python -m unittest tests/test_convert_bids_to_nnUNetV2.py
77 changes: 65 additions & 12 deletions dataset_conversion/convert_bids_to_nnUNetV2.py
@@ -2,6 +2,46 @@
Converts BIDS-structured dataset to the nnUNetv2 dataset format. Full details about
the format can be found here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md
Example of the input BIDS dataset structure:
...
├── sub-045
│ └── anat
│ └── sub-045_T2w.nii.gz
├── sub-046
│ └── anat
│ └── sub-046_T2w.nii.gz
...
├── derivatives
│ └── labels
│ ├── sub-045
│ │ └── anat
│ │ ├── sub-045_T2w_lesion-manual.json
│ │ ├── sub-045_T2w_lesion-manual.nii.gz
│ │ ├── sub-045_T2w_seg-manual.json
│ │ └── sub-045_T2w_seg-manual.nii.gz
│ ├── sub-046
│ └── anat
│ ├── sub-046_T2w_lesion-manual.json
│ ├── sub-046_T2w_lesion-manual.nii.gz
│ ├── sub-046_T2w_seg-manual.json
│ └── sub-046_T2w_seg-manual.nii.gz
...
Example of the output nnUNetv2 dataset structure:
├── dataset.json
├── imagesTr
│ ├── MyDataset-sub-046_000_0000.nii.gz
...
├── imagesTs
│ ├── MyDataset-sub-045_000_0000.nii.gz
...
├── labelsTr
│ ├── MyDataset-sub-046_000.nii.gz
...
└── labelsTs
├── MyDataset-sub-045_000.nii.gz
...
Usage example:
python convert_bids_to_nnUNetv2.py --path-data ~/data/dataset --path-out ~/data/dataset-nnunet
--dataset-name MyDataset --dataset-number 501 --split 0.8 0.2 --seed 99 --copy False
@@ -25,7 +65,6 @@
import pandas as pd
from loguru import logger
from sklearn.model_selection import train_test_split
import nibabel as nib
import numpy as np


@@ -40,6 +79,8 @@ def get_parser():
# TODO accept multi value label
parser.add_argument('--label-suffix', type=str,
help='Label suffix. Example: lesion-manual or seg-manual, if None no label used')
parser.add_argument('--data-type', type=str, default='anat',
help='Type of BIDS dataset used. For example, anat, func, dwi or etc. Default: anat')
parser.add_argument('--dataset-name', '-dname', default='MyDataset', type=str,
help='Specify the task name. Example: MyDataset')
parser.add_argument('--dataset-number', '-dnum', default=501, type=int,
@@ -51,12 +92,12 @@ def get_parser():
help='Ratios of training (includes validation) and test splits lying between 0-1. '
'Example: --split 0.8 0.2')
parser.add_argument('--copy', '-cp', type=bool, default=False,
help='Making symlink (False) or copying (True) the files in the nnUNet dataset, '
'default = False. Example for symlink: --copy True')
help='If used, the files will be copied to the new structure. If not used, the symbolic links '
'will be created. Default: False. Example for copy: --copy True')
return parser


def convert_subject(root, subject, channel, contrast, label_suffix, path_out_images, path_out_labels, counter,
def convert_subject(root, subject, channel, contrast, label_suffix, data_type, path_out_images, path_out_labels, counter,
list_images, list_labels, is_ses, copy, DS_name, session=None):
"""Function to get image from original BIDS dataset modify if needed and place
it with a compatible name in nnUNet dataset.
@@ -83,13 +124,23 @@ def convert_subject(root, subject, channel, contrast, label_suffix, path_out_ima
"""
if is_ses:
subject_image_file = os.path.join(root, subject, session, 'anat', f"{subject}_{session}_{contrast}.nii.gz")
subject_label_file = os.path.join(root, 'derivatives', 'labels', subject, session, 'anat',
subject_image_file = os.path.join(root, subject, session, data_type, f"{subject}_{session}_{contrast}.nii.gz")
subject_label_file = os.path.join(root, 'derivatives', 'labels', subject, session, data_type,
f"{subject}_{session}_{contrast}_{label_suffix}.nii.gz")
sub_name = re.match(r'^([^_]+_[^_]+)', Path(subject_image_file).name).group(1)

elif data_type == 'func':
subject_directory = os.path.join(root, subject, data_type)
all_files = os.listdir(subject_directory)
subject_image_file = os.path.join(subject_directory, [f for f in all_files if f.endswith('nii.gz')][0])
subject_label_directory = os.path.join(root, 'derivatives', 'labels', subject, data_type)
all_label_files = os.listdir(subject_label_directory)
subject_label_file = os.path.join(subject_label_directory, [f for f in all_label_files if f.endswith('nii.gz')][0])
sub_name = re.match(r'^([^_]+)', Path(subject_image_file).name).group(1)

else:
subject_image_file = os.path.join(root, subject, 'anat', f"{subject}_{contrast}.nii.gz")
subject_label_file = os.path.join(root, 'derivatives', 'labels', subject, 'anat',
subject_image_file = os.path.join(root, subject, data_type, f"{subject}_{contrast}.nii.gz")
subject_label_file = os.path.join(root, 'derivatives', 'labels', subject, data_type,
f"{subject}_{contrast}_{label_suffix}.nii.gz")
sub_name = re.match(r'^([^_]+)', Path(subject_image_file).name).group(1)

@@ -137,6 +188,8 @@ def main():
label_suffix = args.label_suffix
if label_suffix is None:
print(f"No suffix label provided, ignoring label to create this dataset")

data_type = args.data_type

# create individual directories for train and test images and labels
path_out_imagesTr = Path(os.path.join(path_out, 'imagesTr'))
Expand Down Expand Up @@ -194,7 +247,7 @@ def main():
train_ctr = len(train_images)
for contrast in contrast_list:
train_images, train_labels = convert_subject(root, subject, channel_dict[contrast], contrast,
label_suffix, path_out_imagesTr, path_out_labelsTr,
label_suffix, data_type, path_out_imagesTr, path_out_labelsTr,
train_ctr + test_ctr, train_images, train_labels,
True, copy, DS_name, session)

@@ -204,7 +257,7 @@
train_ctr = len(train_images)
for contrast in contrast_list:
train_images, train_labels = convert_subject(root, subject, channel_dict[contrast], contrast,
label_suffix, path_out_imagesTr, path_out_labelsTr,
label_suffix, data_type, path_out_imagesTr, path_out_labelsTr,
train_ctr + test_ctr, train_images, train_labels,
False, copy, DS_name)

@@ -221,7 +274,7 @@
test_ctr = len(test_images)
for contrast in contrast_list:
test_images, test_labels = convert_subject(root, subject, channel_dict[contrast], contrast,
label_suffix, path_out_imagesTs, path_out_labelsTs,
label_suffix, data_type, path_out_imagesTs, path_out_labelsTs,
train_ctr + test_ctr, test_images, test_labels, True,
copy, DS_name, session)

@@ -231,7 +284,7 @@
test_ctr = len(test_images)
for contrast in contrast_list:
test_images, test_labels = convert_subject(root, subject, channel_dict[contrast], contrast,
label_suffix, path_out_imagesTs, path_out_labelsTs,
label_suffix, data_type, path_out_imagesTs, path_out_labelsTs,
train_ctr + test_ctr, test_images, test_labels, False,
copy, DS_name)

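Editor's note: the hunks above show that the new `--data-type` argument feeds directly into path construction inside `convert_subject`. Below is a minimal sketch of that logic, not the repository's API: the helper name `resolve_paths` and the example values for `root`, `subject`, `contrast`, and `label_suffix` are hypothetical, while the two branches mirror the diff.

    import os

    # Hypothetical example values (the real script reads these from the CLI / BIDS layout).
    root, subject, contrast, label_suffix = "/path/to/bids", "sub-001", "T2w", "seg-manual"

    def resolve_paths(data_type):
        # Sketch only: mirrors the branch added to convert_subject in this PR.
        if data_type == 'func':
            # 'func' filenames do not follow the <subject>_<contrast> pattern,
            # so the first NIfTI found in the data-type folder is used.
            image_dir = os.path.join(root, subject, data_type)
            image = os.path.join(image_dir, [f for f in os.listdir(image_dir) if f.endswith('nii.gz')][0])
            label_dir = os.path.join(root, 'derivatives', 'labels', subject, data_type)
            label = os.path.join(label_dir, [f for f in os.listdir(label_dir) if f.endswith('nii.gz')][0])
        else:
            # 'anat' (and other data types) keep the explicit BIDS naming.
            image = os.path.join(root, subject, data_type, f"{subject}_{contrast}.nii.gz")
            label = os.path.join(root, 'derivatives', 'labels', subject, data_type,
                                 f"{subject}_{contrast}_{label_suffix}.nii.gz")
        return image, label

In practice, the usage example from the docstring gains one extra option, e.g. `--data-type func`, with all other arguments unchanged.
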
14 changes: 14 additions & 0 deletions dataset_conversion/requirements.txt
@@ -0,0 +1,14 @@
joblib
loguru
nibabel
numpy
packaging
pandas
pyarrow
python-dateutil
pytz
scikit-learn
scipy
six
threadpoolctl
tzdata
137 changes: 137 additions & 0 deletions tests/test_convert_bids_to_nnUNetV2.py
@@ -0,0 +1,137 @@
#######################################################################
#
# Tests for the `dataset_conversion/convert_bids_to_nnUNetV2.py` script
#
# RUN BY:
#   python -m unittest tests/test_convert_bids_to_nnUNetV2.py
#######################################################################

import unittest
from unittest.mock import patch
from dataset_conversion.convert_bids_to_nnUNetV2 import convert_subject, get_parser


class TestConvertBidsToNnunet(unittest.TestCase):
    """
    Test the conversion of BIDS dataset to nnUNetV2 dataset with data_type = "anat"
    """

    def setUp(self):
        self.root = "/path/to/bids"
        self.subject = "sub-001"
        self.contrast = "T2w"
        self.label_suffix = "lesion-manual"
        self.data_type = "anat"
        self.path_out_images = "/path/to/nnunet/imagesTr"
        self.path_out_labels = "/path/to/nnunet/labelsTr"
        self.counter = 0
        self.list_images = []
        self.list_labels = []
        self.is_ses = False
        self.copy = False
        self.DS_name = "MyDataset"
        self.session = None
        self.channel = 0

    @patch('os.path.exists')
    @patch('os.symlink')
    @patch('shutil.copy2')
    def test_convert_subject(self, mock_copy, mock_symlink, mock_exists):
        # Setup mock responses
        mock_exists.side_effect = lambda x: True  # Simulate that all files exist

        # Execute the function
        result_images, result_labels = convert_subject(
            self.root, self.subject, self.channel, self.contrast, self.label_suffix,
            self.data_type, self.path_out_images, self.path_out_labels, self.counter,
            self.list_images, self.list_labels, self.is_ses, self.copy, self.DS_name,
            self.session
        )

        # Assert conditions
        self.assertEqual(len(result_images), 1)
        self.assertEqual(len(result_labels), 1)
        if self.copy:
            mock_copy.assert_called()
        else:
            mock_symlink.assert_called()

    def test_argument_parsing(self):
        parser = get_parser()
        args = parser.parse_args([
            '--path-data', '/path/to/bids', '--path-out', '/path/to/nnunet',
            '--contrast', 'T2w', '--label-suffix', 'lesion-manual',
            '--data-type', 'anat', '--dataset-name', 'MyDataset',
            '--dataset-number', '501', '--split', '0.8', '0.2', '--seed', '99', '--copy', 'True'
        ])
        self.assertEqual(args.path_data, '/path/to/bids')
        self.assertEqual(args.path_out, '/path/to/nnunet')
        self.assertEqual(args.contrast, ['T2w'])
        self.assertEqual(args.label_suffix, 'lesion-manual')
        self.assertEqual(args.data_type, 'anat')
        self.assertEqual(args.dataset_name, 'MyDataset')
        self.assertEqual(args.dataset_number, 501)
        self.assertEqual(args.split, [0.8, 0.2])
        self.assertEqual(args.seed, 99)
        self.assertEqual(args.copy, True)


class TestConvertBidsToNnunetFuncDataType(unittest.TestCase):
    """
    Test the conversion of BIDS dataset to nnUNetV2 dataset with data_type = "func"
    """

    def setUp(self):
        # Setup common test data for the "func" data type scenario
        self.root = "/path/to/bids"
        self.subject = "sub-001"
        self.contrast = "bold"
        self.label_suffix = "task-rest"
        self.data_type = "func"
        self.path_out_images = "/path/to/nnunet/imagesTr"
        self.path_out_labels = "/path/to/nnunet/labelsTr"
        self.counter = 0
        self.list_images = []
        self.list_labels = []
        self.is_ses = False
        self.copy = False
        self.DS_name = "MyDataset"
        self.session = None
        self.channel = 0

    @patch('os.path.exists')
    @patch('os.symlink')
    @patch('shutil.copy2')
    @patch('os.listdir')
    def test_convert_subject_func(self, mock_listdir, mock_copy, mock_symlink, mock_exists):
        # Mock the os.listdir to return files simulating a "func" directory structure
        mock_listdir.side_effect = lambda x: ["sub-001_task-rest_bold.nii.gz"]
        # Mock os.path.exists to simulate that all necessary files exist
        mock_exists.side_effect = lambda x: True

        # Execute the function with "func" data type
        result_images, result_labels = convert_subject(
            self.root, self.subject, self.channel, self.contrast, self.label_suffix,
            self.data_type, self.path_out_images, self.path_out_labels, self.counter,
            self.list_images, self.list_labels, self.is_ses, self.copy, self.DS_name,
            self.session
        )

        # Assert conditions specific to "func" data type processing
        self.assertEqual(len(result_images), 1, "Should have added one image path for 'func' data type")
        self.assertEqual(len(result_labels), 1, "Should have added one label path for 'func' data type")

        expected_image_path = f"{self.path_out_images}/{self.DS_name}-sub-001_{self.counter:03d}_{self.channel:04d}.nii.gz"
        expected_label_path = f"{self.path_out_labels}/{self.DS_name}-sub-001_{self.counter:03d}.nii.gz"

        self.assertIn(expected_image_path, result_images, "The image path for 'func' data type is not as expected")
        self.assertIn(expected_label_path, result_labels, "The label path for 'func' data type is not as expected")

        if self.copy:
            mock_copy.assert_called()
        else:
            mock_symlink.assert_called()


if __name__ == '__main__':
    unittest.main()
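Editor's note: the expected paths asserted in `test_convert_subject_func` encode the nnUNetv2 file-naming scheme the converter produces (compare the `MyDataset-sub-046_000_0000.nii.gz` examples in the script's docstring). A quick reference, assuming the hypothetical values from `setUp` above:

    # Image names follow <dataset>-<subject>_<counter>_<channel>, labels drop the channel suffix.
    DS_name, sub_name, counter, channel = "MyDataset", "sub-001", 0, 0
    image_name = f"{DS_name}-{sub_name}_{counter:03d}_{channel:04d}.nii.gz"  # MyDataset-sub-001_000_0000.nii.gz
    label_name = f"{DS_name}-{sub_name}_{counter:03d}.nii.gz"                # MyDataset-sub-001_000.nii.gz
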
