From 22ee0d4e03120973ad84e44df818a85593193bc8 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 5 Feb 2024 02:38:29 -0500 Subject: [PATCH 01/15] add dataset conversion and utils scripts --- .../convert_bids_to_nnUNetv2.py | 398 +++++++++ dataset-conversion/utils.py | 781 ++++++++++++++++++ 2 files changed, 1179 insertions(+) create mode 100644 dataset-conversion/convert_bids_to_nnUNetv2.py create mode 100644 dataset-conversion/utils.py diff --git a/dataset-conversion/convert_bids_to_nnUNetv2.py b/dataset-conversion/convert_bids_to_nnUNetv2.py new file mode 100644 index 0000000..9217b27 --- /dev/null +++ b/dataset-conversion/convert_bids_to_nnUNetv2.py @@ -0,0 +1,398 @@ +""" +Convert BIDS-structured datasets (dcm-zurich-lesions, dcm-zurich-lesions-20231115) to the nnUNetv2 dataset format. +Full details about the format can be found here: +https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md + +The script to be used on a single dataset or multiple datasets. + +The script in default creates region-based labels for segmenting both lesion and the spinal cord. + +Currently only supports the conversion of a single contrast. In case of multiple contrasts, the script should be +modified to include those as well. + +Note: the script performs RPI reorientation of the images and labels + +Usage example multiple datasets: + python convert_bids_to_nnUNetv2_praxis.py + --path-data ~/data/dcm-zurich-lesions ~/data/dcm-zurich-lesions-20231115 + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + --region-based + +Usage example single dataset: + python convert_bids_to_nnUNetv2_praxis.py + --path-data ~/data/dcm-zurich-lesions + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + --region-based + +Authors: Naga Karthik, Jan Valosek +""" + +import argparse +from pathlib import Path +import json +import os +import re +import shutil +import yaml +from collections import OrderedDict +from loguru import logger +from sklearn.model_selection import train_test_split +from utils import binarize_label, create_region_based_label, get_git_branch_and_commit, Image +from tqdm import tqdm + +import nibabel as nib + + +def get_parser(): + # parse command line arguments + parser = argparse.ArgumentParser(description='Convert BIDS-structured dataset to nnUNetV2 database format.') + parser.add_argument('--path-data', nargs='+', required=True, type=str, + help='Path to BIDS dataset(s) (list).') + parser.add_argument('--path-out', help='Path to output directory.', required=True) + parser.add_argument('--dataset-name', '-dname', default='DCMlesions', type=str, + help='Specify the task name.') + parser.add_argument('--dataset-number', '-dnum', default=601, type=int, + help='Specify the task number, has to be greater than 500 but less than 999. e.g 502') + parser.add_argument('--seed', default=42, type=int, + help='Seed to be used for the random number generator split into training and test sets.') + parser.add_argument('--region-based', action='store_true', default=True, + help='If set, the script will create labels for region-based nnUNet training. Default: True') + # argument that accepts a list of floats as train val test splits + parser.add_argument('--split', nargs='+', type=float, default=[0.8, 0.2], + help='Ratios of training (includes validation) and test splits lying between 0-1. 
Example: ' + '--split 0.8 0.2') + return parser + + +def get_region_based_label(subject_label_file, subject_image_file, sub_ses_name, thr=0.5): + # define path for sc seg file + subject_seg_file = subject_label_file.replace('_label-lesion', '_label-SC_mask-manual') + + # check if the seg file exists + if not os.path.exists(subject_seg_file): + logger.info(f"Spinal cord segmentation file for subject {sub_ses_name} does not exist. Skipping.") + return None + + # create region-based label + seg_lesion_nii = create_region_based_label(subject_label_file, subject_seg_file, subject_image_file, + sub_ses_name, thr=thr) + + # save the region-based label + combined_seg_file = subject_label_file.replace('_label-lesion', '_SC-lesion') + nib.save(seg_lesion_nii, combined_seg_file) + + return combined_seg_file + + +def create_directories(path_out, site): + """Create test directories for a specified site. + + Args: + path_out (str): Base output directory. + site (str): Site identifier, such as 'dcm-zurich-lesions + """ + paths = [Path(path_out, f'imagesTs_{site}'), + Path(path_out, f'labelsTs_{site}')] + + for path in paths: + path.mkdir(parents=True, exist_ok=True) + + +def find_site_in_path(path): + """Extracts site identifier from the given path. + + Args: + path (str): Input path containing a site identifier. + + Returns: + str: Extracted site identifier or None if not found. + """ + # Find 'dcm-zurich-lesions' or 'dcm-zurich-lesions-20231115' + match = re.search(r'dcm-zurich-lesions(-\d{8})?', path) + return match.group(0) if match else None + + +def create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits): + # create a yaml file containing the list of training and test niftis + niftis_dict = { + f"train": sorted(train_niftis), + f"test": sorted(test_nifitis) + } + + # write the train and test niftis to a yaml file + with open(os.path.join(path_out, f"train_test_split_seed{args.seed}.yaml"), "w") as outfile: + yaml.dump(niftis_dict, outfile, default_flow_style=False) + + # c.f. dataset json generation + # In nnUNet V2, dataset.json file has become much shorter. The description of the fields and changes + # can be found here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md#datasetjson + # this file can be automatically generated using the following code here: + # https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/dataset_conversion/generate_dataset_json.py + # example: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task055_SegTHOR.py + + json_dict = OrderedDict() + json_dict['name'] = args.dataset_name + json_dict['description'] = args.dataset_name + json_dict['reference'] = "TBD" + json_dict['licence'] = "TBD" + json_dict['release'] = "0.0" + json_dict['numTraining'] = train_ctr + json_dict['numTest'] = test_ctr + json_dict['seed_used'] = args.seed + json_dict['dataset_versions'] = dataset_commits + json_dict['image_orientation'] = "RPI" + + # The following keys are the most important ones. + """ + channel_names: + Channel names must map the index to the name of the channel. For BIDS, this refers to the contrast suffix. + { + 0: 'T1', + 1: 'CT' + } + Note that the channel names may influence the normalization scheme!! Learn more in the documentation. + + labels: + This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. 
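+    (For reference: in this script, the region-based case maps "sc" to the union
+    (1, 2) and "lesion" to 2; see `json_dict['labels']` and `regions_class_order`
+    below.)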
+ Example regular labels: + { + 'background': 0, + 'left atrium': 1, + 'some other label': 2 + } + Example region-based training: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/region_based_training.md + { + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': 3 + } + Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! + """ + + json_dict['channel_names'] = { + 0: "acq-sag_T2w", + } + + if not args.region_based: + json_dict['labels'] = { + "background": 0, + "lesion": 1, + } + else: + json_dict['labels'] = { + "background": 0, + "sc": [1, 2], + # "sc": 1, + "lesion": 2, + } + json_dict['regions_class_order'] = [1, 2] + + # Needed for finding the files correctly. IMPORTANT! File endings must match between images and segmentations! + json_dict['file_ending'] = ".nii.gz" + + # create dataset_description.json + json_object = json.dumps(json_dict, indent=4) + # write to dataset description + # nn-unet requires it to be "dataset.json" + dataset_dict_name = f"dataset.json" + with open(os.path.join(path_out, dataset_dict_name), "w") as outfile: + outfile.write(json_object) + + +def main(): + parser = get_parser() + args = parser.parse_args() + + train_ratio, test_ratio = args.split + path_out = Path(os.path.join(os.path.abspath(args.path_out), f'Dataset{args.dataset_number}_{args.dataset_name}')) + + # create individual directories for train and test images and labels + path_out_imagesTr = Path(os.path.join(path_out, 'imagesTr')) + path_out_labelsTr = Path(os.path.join(path_out, 'labelsTr')) + # create the training directories + Path(path_out).mkdir(parents=True, exist_ok=True) + Path(path_out_imagesTr).mkdir(parents=True, exist_ok=True) + Path(path_out_labelsTr).mkdir(parents=True, exist_ok=True) + + # save output to a log file + logger.add(os.path.join(path_out, "logs.txt"), rotation="10 MB", level="INFO") + + # Check if dataset paths exist + for path in args.path_data: + if not os.path.exists(path): + raise ValueError(f"Path {path} does not exist.") + + # Get sites from the input paths + sites = set(find_site_in_path(path) for path in args.path_data if find_site_in_path(path)) + # Single site + if len(sites) == 1: + create_directories(path_out, sites.pop()) + # Multiple sites + else: + for site in sites: + create_directories(path_out, site) + + all_lesion_files, train_images, test_images = [], {}, {} + # temp dict for storing dataset commits + dataset_commits = {} + + # loop over the datasets + for dataset in args.path_data: + root = Path(dataset) + + # get the git branch and commit ID of the dataset + dataset_name = os.path.basename(os.path.normpath(dataset)) + branch, commit = get_git_branch_and_commit(dataset) + dataset_commits[dataset_name] = f"git-{branch}-{commit}" + + # get recursively all GT '_label-lesion' files + lesion_files = [str(path) for path in root.rglob('*_label-lesion.nii.gz')] + + # add to the list of all subjects + all_lesion_files.extend(lesion_files) + + # Get the training and test splits + tr_subs, te_subs = train_test_split(lesion_files, test_size=test_ratio, random_state=args.seed) + + # update the train and test images dicts with the key as the subject and value as the path to the subject + train_images.update({sub: os.path.join(root, sub) for sub in tr_subs}) + test_images.update({sub: os.path.join(root, sub) for sub in te_subs}) + + logger.info(f"Found subjects in the training set (combining all datasets): {len(train_images)}") + logger.info(f"Found subjects in the 
test set (combining all datasets): {len(test_images)}") + # Print test images for each site + for site in sites: + logger.info(f"Test subjects in {site}: {len([sub for sub in test_images if site in sub])}") + + # print version of each dataset in a separate line + for dataset_name, dataset_commit in dataset_commits.items(): + logger.info(f"{dataset_name} dataset version: {dataset_commit}") + + # Counters for train and test sets + train_ctr, test_ctr = 0, 0 + train_niftis, test_nifitis = [], [] + # Loop over all images + for subject_label_file in tqdm(all_lesion_files, desc="Iterating over all images"): + + # Construct path to the background image + subject_image_file = subject_label_file.replace('/derivatives/labels', '').replace('_label-lesion', '') + + # Train images + if subject_label_file in train_images.keys(): + + train_ctr += 1 + # add the subject image file to the list of training niftis + train_niftis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + subject_image_file_nnunet = os.path.join(path_out_imagesTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}_0000.nii.gz") + subject_label_file_nnunet = os.path.join(path_out_labelsTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}.nii.gz") + + # use region-based labels if required + if args.region_based: + # overwritten the subject_label_file with the region-based label + subject_label_file = get_region_based_label(subject_label_file, + subject_image_file, sub_name, thr=0.5) + if subject_label_file is None: + print(f"Skipping since the region-based label could not be generated") + continue + + # copy the files to new structure + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # binarize the label file only if region-based training is not set (since the region-based labels are + # already binarized) + if not args.region_based: + binarize_label(subject_image_file_nnunet, subject_label_file_nnunet) + + # Test images + elif subject_label_file in test_images: + + test_ctr += 1 + # add the image file to the list of testing niftis + test_nifitis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + subject_image_file_nnunet = os.path.join(Path(path_out, + f'imagesTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}_0000.nii.gz') + subject_label_file_nnunet = os.path.join(Path(path_out, + f'labelsTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}.nii.gz') + + # use region-based labels if required + if args.region_based and find_site_in_path(test_images[subject_label_file]) != 'site_014': + # overwritten the subject_label_file with the region-based label + subject_label_file = get_region_based_label(subject_label_file, + subject_image_file, sub_name, thr=0.5) + if subject_label_file is None: + continue + + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + print(f"\nCopying {subject_label_file} to 
{subject_label_file_nnunet}") + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # copy the files to new structure + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + print(f"\nCopying {subject_image_file} to {subject_image_file_nnunet}") + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + # binarize the label file only if region-based training is not set (since the region-based labels are + # already binarized) + if not args.region_based: + binarize_label(subject_image_file_nnunet, subject_label_file_nnunet) + + else: + print("Skipping file, could not be located in the Train or Test splits split.", subject_label_file) + + logger.info(f"----- Dataset conversion finished! -----") + logger.info(f"Number of training and validation images (across all sites): {train_ctr}") + logger.info(f"Number of test images (across all sites): {test_ctr}") + # Get number of test images per site + test_images_per_site = {} + for test_subject in test_images: + site = find_site_in_path(test_subject) + if site in test_images_per_site: + test_images_per_site[site] += 1 + else: + test_images_per_site[site] = 1 + # Print number of test images per site + for site, num_images in test_images_per_site.items(): + logger.info(f"Number of test images in {site}: {num_images}") + + # create the yaml file containing the train and test niftis + create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dataset-conversion/utils.py b/dataset-conversion/utils.py new file mode 100644 index 0000000..ca184d3 --- /dev/null +++ b/dataset-conversion/utils.py @@ -0,0 +1,781 @@ +import os +import nibabel as nib +import numpy as np +import logging +from copy import deepcopy +import subprocess + +logger = logging.getLogger(__name__) + + +def binarize_label(subject_path, label_path): + label_npy = nib.load(label_path).get_fdata() + threshold = 0.5 + label_npy = np.where(label_npy > threshold, 1, 0) + ref = nib.load(subject_path) + label_bin = nib.Nifti1Image(label_npy, ref.affine, ref.header) + # overwrite the original label file with the binarized version + nib.save(label_bin, label_path) + + +def create_region_based_label(lesion_label_file, seg_label_file, image_file, sub_ses_name, thr=0.5): + """ + Creates region-based labels for nnUNet training. The regions are: + 0: background + 1: spinal cord seg + 2: lesion seg + """ + # load the labels + lesion_label_npy = nib.load(lesion_label_file).get_fdata() + seg_label_npy = nib.load(seg_label_file).get_fdata() + + # binarize the labels + lesion_label_npy = np.where(lesion_label_npy > thr, 1, 0) + seg_label_npy = np.where(seg_label_npy > thr, 1, 0) + + # check if the shapes of the labels match + assert lesion_label_npy.shape == seg_label_npy.shape, \ + f'Shape mismatch between lesion label and segmentation label for subject {sub_ses_name}. Check the labels.' + + # create a new label array with the same shape as the original labels + label_npy = np.zeros(lesion_label_npy.shape, dtype=np.int16) + # spinal cord + label_npy[seg_label_npy == 1] = 1 + # lesion seg + label_npy[lesion_label_npy == 1] = 2 + # TODO: what happens when the subject has no lesion? 
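+    # (A likely answer, not verified on every dataset: if the subject has no
+    # lesion, `label_npy` contains only {0, 1}, so the lesion region is simply
+    # empty and region-based training sees no positive voxels for class 2.)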
+ + # print unique values in the label array + # print(f'Unique values in the label array for subject {sub_ses_name}: {np.unique(label_npy)}') + + # save the new label file + ref = nib.load(image_file) + label_nii = nib.Nifti1Image(label_npy, ref.affine, ref.header) + + return label_nii + + +def get_git_branch_and_commit(dataset_path=None): + """ + :return: git branch and commit ID, with trailing '*' if modified + Taken from: https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/utils/sys.py#L476 + and https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/utils/sys.py#L461 + """ + + # branch info + b = subprocess.Popen(["git", "rev-parse", "--abbrev-ref", "HEAD"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, cwd=dataset_path) + b_output, _ = b.communicate() + b_status = b.returncode + + if b_status == 0: + branch = b_output.decode().strip() + else: + branch = "!?!" + + # commit info + p = subprocess.Popen(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=dataset_path) + output, _ = p.communicate() + status = p.returncode + if status == 0: + commit = output.decode().strip() + else: + commit = "?!?" + + p = subprocess.Popen(["git", "status", "--porcelain"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=dataset_path) + output, _ = p.communicate() + status = p.returncode + if status == 0: + unclean = True + for line in output.decode().strip().splitlines(): + line = line.rstrip() + if line.startswith("??"): # ignore ignored files, they can't hurt + continue + break + else: + unclean = False + if unclean: + commit += "*" + + return branch, commit + + +class Image(object): + """ + Compact version of SCT's Image Class (https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/image.py#L245) + Create an object that behaves similarly to nibabel's image object. Useful additions include: dims, change_orientation and getNonZeroCoordinates. + Taken from: https://github.com/ivadomed/utilities/blob/main/scripts/image.py + Changed default verbosity to 0. + """ + + def __init__(self, param=None, hdr=None, orientation=None, absolutepath=None, dim=None): + """ + :param param: string indicating a path to a image file or an `Image` object. 
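+        A list of dimensions (to create a blank image) or a numpy array (to wrap
+        an existing data array) are also accepted; `hdr` then supplies an optional
+        nibabel header (see Cases 3 and 4 below).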
+ """ + + # initialization of all parameters + self.affine = None + self.data = None + self._path = None + self.ext = "" + + if absolutepath is not None: + self._path = os.path.abspath(absolutepath) + + # Case 1: load an image from file + if isinstance(param, str): + self.loadFromPath(param) + # Case 2: create a copy of an existing `Image` object + elif isinstance(param, type(self)): + self.copy(param) + # Case 3: create a blank image from a list of dimensions + elif isinstance(param, list): + self.data = np.zeros(param) + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + # Case 4: create an image from an existing data array + elif isinstance(param, (np.ndarray, np.generic)): + self.data = param + self.hdr = hdr.copy() if hdr is not None else nib.Nifti1Header() + self.hdr.set_data_shape(self.data.shape) + else: + raise TypeError('Image constructor takes at least one argument.') + + # Fix any mismatch between the array's datatype and the header datatype + self.fix_header_dtype() + + @property + def dim(self): + return get_dimension(self) + + @property + def orientation(self): + return get_orientation(self) + + @property + def absolutepath(self): + """ + Storage path (either actual or potential) + + Notes: + + - As several tools perform chdir() it's very important to have absolute paths + - When set, if relative: + + - If it already existed, it becomes a new basename in the old dirname + - Else, it becomes absolute (shortcut) + + Usually not directly touched (use `Image.save`), but in some cases it's + the best way to set it. + """ + return self._path + + @absolutepath.setter + def absolutepath(self, value): + if value is None: + self._path = None + return + elif not os.path.isabs(value) and self._path is not None: + value = os.path.join(os.path.dirname(self._path), value) + elif not os.path.isabs(value): + value = os.path.abspath(value) + self._path = value + + @property + def header(self): + return self.hdr + + @header.setter + def header(self, value): + self.hdr = value + + def __deepcopy__(self, memo): + return type(self)(deepcopy(self.data, memo), deepcopy(self.hdr, memo), deepcopy(self.orientation, memo), deepcopy(self.absolutepath, memo), deepcopy(self.dim, memo)) + + def copy(self, image=None): + if image is not None: + self.affine = deepcopy(image.affine) + self.data = deepcopy(image.data) + self.hdr = deepcopy(image.hdr) + self._path = deepcopy(image._path) + else: + return deepcopy(self) + + def loadFromPath(self, path): + """ + This function load an image from an absolute path using nibabel library + + :param path: path of the file from which the image will be loaded + :return: + """ + + self.absolutepath = os.path.abspath(path) + im_file = nib.load(self.absolutepath, mmap=True) + self.affine = im_file.affine.copy() + self.data = np.asanyarray(im_file.dataobj) + self.hdr = im_file.header.copy() + if path != self.absolutepath: + logger.debug("Loaded %s (%s) orientation %s shape %s", path, self.absolutepath, self.orientation, self.data.shape) + else: + logger.debug("Loaded %s orientation %s shape %s", path, self.orientation, self.data.shape) + + def change_orientation(self, orientation, inverse=False): + """ + Change orientation on image (in-place). + + :param orientation: orientation string (SCT "from" convention) + + :param inverse: if you think backwards, use this to specify that you actually\ + want to transform *from* the specified orientation, not *to*\ + it. 
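+
+        Example (the pattern used by the conversion scripts, reorienting in-place):
+            Image("img.nii.gz").change_orientation("RPI").save("img.nii.gz")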
+ + """ + change_orientation(self, orientation, self, inverse=inverse) + return self + + def getNonZeroCoordinates(self, sorting=None, reverse_coord=False): + """ + This function return all the non-zero coordinates that the image contains. + Coordinate list can also be sorted by x, y, z, or the value with the parameter sorting='x', sorting='y', sorting='z' or sorting='value' + If reverse_coord is True, coordinate are sorted from larger to smaller. + + Removed Coordinate object + """ + n_dim = 1 + if self.dim[3] == 1: + n_dim = 3 + else: + n_dim = 4 + if self.dim[2] == 1: + n_dim = 2 + + if n_dim == 3: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], Z[i], self.data[X[i], Y[i], Z[i]]] for i in range(0, len(X))] + elif n_dim == 2: + try: + X, Y = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i]]] for i in range(0, len(X))] + except ValueError: + X, Y, Z = (self.data > 0).nonzero() + list_coordinates = [[X[i], Y[i], 0, self.data[X[i], Y[i], 0]] for i in range(0, len(X))] + + if sorting is not None: + if reverse_coord not in [True, False]: + raise ValueError('reverse_coord parameter must be a boolean') + + if sorting == 'x': + list_coordinates = sorted(list_coordinates, key=lambda el: el[0], reverse=reverse_coord) + elif sorting == 'y': + list_coordinates = sorted(list_coordinates, key=lambda el: el[1], reverse=reverse_coord) + elif sorting == 'z': + list_coordinates = sorted(list_coordinates, key=lambda el: el[2], reverse=reverse_coord) + elif sorting == 'value': + list_coordinates = sorted(list_coordinates, key=lambda el: el[3], reverse=reverse_coord) + else: + raise ValueError("sorting parameter must be either 'x', 'y', 'z' or 'value'") + + return list_coordinates + + def change_type(self, dtype): + """ + Change data type on image. + + Note: the image path is voided. + """ + change_type(self, dtype, self) + return self + + def fix_header_dtype(self): + """ + Change the header dtype to the match the datatype of the array. + """ + # Using bool for nibabel headers is unsupported, so use uint8 instead: + # `nibabel.spatialimages.HeaderDataError: data dtype "bool" not supported` + dtype_data = self.data.dtype + if dtype_data == bool: + dtype_data = np.uint8 + + dtype_header = self.hdr.get_data_dtype() + if dtype_header != dtype_data: + logger.warning(f"Image header specifies datatype '{dtype_header}', but array is of type " + f"'{dtype_data}'. Header metadata will be overwritten to use '{dtype_data}'.") + self.hdr.set_data_dtype(dtype_data) + + def save(self, path=None, dtype=None, verbose=0, mutable=False): + """ + Write an image in a nifti file + + :param path: Where to save the data, if None it will be taken from the\ + absolutepath member.\ + If path is a directory, will save to a file under this directory\ + with the basename from the absolutepath member. 
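+
+        (If `path` is an existing directory, the basename of `absolutepath` is
+        reused inside it, e.g. `im.save("out_dir/")` writes `out_dir/<basename>`.)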
+ + :param dtype: if not set, the image is saved in the same type as input data\ + if 'minimize', image storage space is minimized\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + + :param mutable: whether to update members with newly created path or dtype + """ + if mutable: # do all modifications in-place + # Case 1: `path` not specified + if path is None: + if self.absolutepath: # Fallback to the original filepath + path = self.absolutepath + else: + raise ValueError("Don't know where to save the image (no absolutepath or path parameter)") + # Case 2: `path` points to an existing directory + elif os.path.isdir(path): + if self.absolutepath: # Use the original filename, but save to the directory specified by `path` + path = os.path.join(os.path.abspath(path), os.path.basename(self.absolutepath)) + else: + raise ValueError("Don't know where to save the image (path parameter is dir, but absolutepath is " + "missing)") + # Case 3: `path` points to a file (or a *nonexistent* directory) so use its value as-is + # (We're okay with letting nonexistent directories slip through, because it's difficult to distinguish + # between nonexistent directories and nonexistent files. Plus, `nibabel` will catch any further errors.) + else: + pass + + if os.path.isfile(path) and verbose: + logger.warning("File %s already exists. Will overwrite it.", path) + if os.path.isabs(path): + logger.debug("Saving image to %s orientation %s shape %s", + path, self.orientation, self.data.shape) + else: + logger.debug("Saving image to %s (%s) orientation %s shape %s", + path, os.path.abspath(path), self.orientation, self.data.shape) + + # Now that `path` has been set and log messages have been written, we can assign it to the image itself + self.absolutepath = os.path.abspath(path) + + if dtype is not None: + self.change_type(dtype) + + if self.hdr is not None: + self.hdr.set_data_shape(self.data.shape) + self.fix_header_dtype() + + # nb. that copy() is important because if it were a memory map, save() would corrupt it + dataobj = self.data.copy() + affine = None + header = self.hdr.copy() if self.hdr is not None else None + nib.save(nib.nifti1.Nifti1Image(dataobj, affine, header), self.absolutepath) + if not os.path.isfile(self.absolutepath): + raise RuntimeError(f"Couldn't save image to {self.absolutepath}") + else: + # if we're not operating in-place, then make any required modifications on a throw-away copy + self.copy().save(path, dtype, verbose, mutable=True) + return self + + +class SlicerOneAxis(object): + """ + Image slicer to use when you don't care about the 2D slice orientation, + and don't want to specify them. + The slicer will just iterate through the right axis that corresponds to + its specification. + + Can help getting ranges and slice indices. 
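+
+    Example: `SlicerOneAxis(im, axis="IS")` iterates along the inferior-superior
+    axis regardless of the orientation of `im` (this is what `find_zmin_zmax`
+    below relies on).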
+ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + + def __init__(self, im, axis="IS"): + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + axis_labels = "LRPAIS" + if len(axis) != 2: + raise ValueError() + if axis[0] not in axis_labels: + raise ValueError() + if axis[1] not in axis_labels: + raise ValueError() + if axis[0] != opposite_character[axis[1]]: + raise ValueError() + + for idx_axis in range(2): + dim_nr = im.orientation.find(axis[idx_axis]) + if dim_nr != -1: + break + if dim_nr == -1: + raise ValueError() + + # SCT convention + from_dir = im.orientation[dim_nr] + self.direction = +1 if axis[0] == from_dir else -1 + self.nb_slices = im.dim[dim_nr] + self.im = im + self.axis = axis + self._slice = lambda idx: tuple([(idx if x in axis else slice(None)) for x in im.orientation]) + + def __len__(self): + return self.nb_slices + + def __getitem__(self, idx): + """ + + :return: an image slice, at slicing index idx + :param idx: slicing index (according to the slicing direction) + """ + if isinstance(idx, slice): + raise NotImplementedError() + + if idx >= self.nb_slices: + raise IndexError("I just have {} slices!".format(self.nb_slices)) + + if self.direction == -1: + idx = self.nb_slices - 1 - idx + + return self.im.data[self._slice(idx)] + +def get_dimension(im_file, verbose=1): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + Get dimension from Image or nibabel object. Manages 2D, 3D or 4D images. + + :param: im_file: Image or nibabel object + :return: nx, ny, nz, nt, px, py, pz, pt + """ + if not isinstance(im_file, (nib.nifti1.Nifti1Image, Image)): + raise TypeError("The provided image file is neither a nibabel.nifti1.Nifti1Image instance nor an Image instance") + # initializating ndims [nx, ny, nz, nt] and pdims [px, py, pz, pt] + ndims = [1, 1, 1, 1] + pdims = [1, 1, 1, 1] + data_shape = im_file.header.get_data_shape() + zooms = im_file.header.get_zooms() + for i in range(min(len(data_shape), 4)): + ndims[i] = data_shape[i] + pdims[i] = zooms[i] + return *ndims, *pdims + + +def change_orientation(im_src, orientation, im_dst=None, inverse=False): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src: source image + :param orientation: orientation string (SCT "from" convention) + :param im_dst: destination image (can be the source image for in-place + operation, can be unset to generate one) + :param inverse: if you think backwards, use this to specify that you actually + want to transform *from* the specified orientation, not *to* it. + :return: an image with changed orientation + + .. 
note:: + - the resulting image has no path member set + - if the source image is < 3D, it is reshaped to 3D and the destination is 3D + """ + + if len(im_src.data.shape) < 3: + pass # Will reshape to 3D + elif len(im_src.data.shape) == 3: + pass # OK, standard 3D volume + elif len(im_src.data.shape) == 4: + pass # OK, standard 4D volume + elif len(im_src.data.shape) == 5 and im_src.header.get_intent()[0] == "vector": + pass # OK, physical displacement field + else: + raise NotImplementedError("Don't know how to change orientation for this image") + + im_src_orientation = im_src.orientation + im_dst_orientation = orientation + if inverse: + im_src_orientation, im_dst_orientation = im_dst_orientation, im_src_orientation + + perm, inversion = _get_permutations(im_src_orientation, im_dst_orientation) + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + im_src_data = im_src.data + if len(im_src_data.shape) < 3: + im_src_data = im_src_data.reshape(tuple(list(im_src_data.shape) + ([1] * (3 - len(im_src_data.shape))))) + + # Update data by performing inversions and swaps + + # axes inversion (flip) + data = im_src_data[::inversion[0], ::inversion[1], ::inversion[2]] + + # axes manipulations (transpose) + if perm == [1, 0, 2]: + data = np.swapaxes(data, 0, 1) + elif perm == [2, 1, 0]: + data = np.swapaxes(data, 0, 2) + elif perm == [0, 2, 1]: + data = np.swapaxes(data, 1, 2) + elif perm == [2, 0, 1]: + data = np.swapaxes(data, 0, 2) # transform [2, 0, 1] to [1, 0, 2] + data = np.swapaxes(data, 0, 1) # transform [1, 0, 2] to [0, 1, 2] + elif perm == [1, 2, 0]: + data = np.swapaxes(data, 0, 2) # transform [1, 2, 0] to [0, 2, 1] + data = np.swapaxes(data, 1, 2) # transform [0, 2, 1] to [0, 1, 2] + elif perm == [0, 1, 2]: + # do nothing + pass + else: + raise NotImplementedError() + + # Update header + + im_src_aff = im_src.hdr.get_best_affine() + aff = nib.orientations.inv_ornt_aff( + np.array((perm, inversion)).T, + im_src_data.shape) + im_dst_aff = np.matmul(im_src_aff, aff) + + im_dst.header.set_qform(im_dst_aff) + im_dst.header.set_sform(im_dst_aff) + im_dst.header.set_data_shape(data.shape) + im_dst.data = data + + return im_dst + + +def _get_permutations(im_src_orientation, im_dst_orientation): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im_src_orientation str: Orientation of source image. Example: 'RPI' + :param im_dest_orientation str: Orientation of destination image. Example: 'SAL' + :return: list of axes permutations and list of inversions to achieve an orientation change + """ + + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + + perm = [0, 1, 2] + inversion = [1, 1, 1] + for i, character in enumerate(im_src_orientation): + try: + perm[i] = im_dst_orientation.index(character) + except ValueError: + perm[i] = im_dst_orientation.index(opposite_character[character]) + inversion[i] = -1 + + return perm, inversion + + +def get_orientation(im): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :param im: an Image + :return: reference space string (ie. 
what's in Image.orientation) + """ + res = "".join(nib.orientations.aff2axcodes(im.hdr.get_best_affine())) + return orientation_string_nib2sct(res) + + +def orientation_string_nib2sct(s): + """ + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + + :return: SCT reference space code from nibabel one + """ + opposite_character = {'L': 'R', 'R': 'L', 'A': 'P', 'P': 'A', 'I': 'S', 'S': 'I'} + return "".join([opposite_character[x] for x in s]) + + +def change_type(im_src, dtype, im_dst=None): + """ + Change the voxel type of the image + + :param dtype: if not set, the image is saved in standard type\ + if 'minimize', image space is minimize\ + if 'minimize_int', image space is minimize and values are approximated to integers\ + (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),\ + (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),\ + (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),\ + (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),\ + (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),\ + (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),\ + (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),\ + (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),\ + (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),\ + (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),\ + (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),\ + (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),\ + (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),\ + (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"), + :return: + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/ + """ + + if im_dst is None: + im_dst = im_src.copy() + im_dst._path = None + + if dtype is None: + return im_dst + + # get min/max from input image + min_in = np.nanmin(im_src.data) + max_in = np.nanmax(im_src.data) + + # find optimum type for the input image + if dtype in ('minimize', 'minimize_int'): + # warning: does not take intensity resolution into account, neither complex voxels + + # check if voxel values are real or integer + isInteger = True + if dtype == 'minimize': + for vox in im_src.data.flatten(): + if int(vox) != vox: + isInteger = False + break + + if isInteger: + if min_in >= 0: # unsigned + if max_in <= np.iinfo(np.uint8).max: + dtype = np.uint8 + elif max_in <= np.iinfo(np.uint16): + dtype = np.uint16 + elif max_in <= np.iinfo(np.uint32).max: + dtype = np.uint32 + elif max_in <= np.iinfo(np.uint64).max: + dtype = np.uint64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + if max_in <= np.iinfo(np.int8).max and min_in >= np.iinfo(np.int8).min: + dtype = np.int8 + elif max_in <= np.iinfo(np.int16).max and min_in >= np.iinfo(np.int16).min: + dtype = np.int16 + elif max_in <= np.iinfo(np.int32).max and min_in >= np.iinfo(np.int32).min: + dtype = np.int32 + elif max_in <= np.iinfo(np.int64).max and min_in >= np.iinfo(np.int64).min: + dtype = np.int64 + else: + raise ValueError("Maximum value of the image is to big to be represented.") + else: + # if max_in <= np.finfo(np.float16).max and min_in >= np.finfo(np.float16).min: + # type = 'np.float16' # not supported by nibabel + if max_in <= np.finfo(np.float32).max and min_in >= np.finfo(np.float32).min: + dtype = np.float32 + elif max_in <= np.finfo(np.float64).max and min_in >= np.finfo(np.float64).min: + dtype = np.float64 + + dtype = to_dtype(dtype) + else: + dtype = to_dtype(dtype) + + # if output type is int, check if it needs intensity rescaling + if "int" in dtype.name: + # get min/max from output type + min_out = 
np.iinfo(dtype).min
+        max_out = np.iinfo(dtype).max
+        # before rescaling, check if there would be an intensity overflow
+
+        if (min_in < min_out) or (max_in > max_out):
+            # This condition is important for binary images since we do not want to scale them
+            logger.warning(f"To avoid intensity overflow due to conversion to '{dtype.name}', intensity will be rescaled to the maximum quantization scale")
+            # rescale intensity
+            data_rescaled = im_src.data * (max_out - min_out) / (max_in - min_in)
+            im_dst.data = data_rescaled - (data_rescaled.min() - min_out)
+
+    # change type of data in both numpy array and nifti header
+    im_dst.data = getattr(np, dtype.name)(im_dst.data)
+    im_dst.hdr.set_data_dtype(dtype)
+    return im_dst
+
+
+def to_dtype(dtype):
+    """
+    Take a dtype specification and return an np.dtype
+
+    :param dtype: dtype specification (string or np.dtype or None are supported for now)
+    :return: dtype or None
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/
+    """
+    # TODO add more or filter on things supported by nibabel
+
+    if dtype is None:
+        return None
+    if isinstance(dtype, type):
+        if isinstance(dtype(0).dtype, np.dtype):
+            return dtype(0).dtype
+    if isinstance(dtype, np.dtype):
+        return dtype
+    if isinstance(dtype, str):
+        return np.dtype(dtype)
+
+    raise TypeError("data type {}: {} not understood".format(dtype.__class__, dtype))
+
+
+def zeros_like(img, dtype=None):
+    """
+    :param img: reference image
+    :param dtype: desired data type (optional)
+    :return: an Image with the same shape and header, filled with zeros
+
+    Similar to numpy.zeros_like(), the goal of the function is to show the developer's
+    intent and avoid doing a copy, which is slower than initialization with a constant.
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py
+    """
+    zimg = Image(np.zeros_like(img.data), hdr=img.hdr.copy())
+    if dtype is not None:
+        zimg.change_type(dtype)
+    return zimg
+
+
+def empty_like(img, dtype=None):
+    """
+    :param img: reference image
+    :param dtype: desired data type (optional)
+    :return: an Image with the same shape and header, whose data is uninitialized
+
+    Similar to numpy.empty_like(), the goal of the function is to show the developer's
+    intent and avoid touching the allocated memory, because it will be written to
+    afterwards.
+
+    Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py
+    """
+    dst = change_type(img, dtype)
+    return dst
+
+
+def find_zmin_zmax(im, threshold=0.1):
+    """
+    Find the min (and max) z-slice index below which (and above which) slices only have voxels below a given threshold.
+
+    :param im: Image object
+    :param threshold: threshold to apply before looking for zmin/zmax, typically corresponding to noise level.
+ :return: [zmin, zmax] + + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox/image.py + """ + slicer = SlicerOneAxis(im, axis="IS") + + # Make sure image is not empty + if not np.any(slicer): + logger.error('Input image is empty') + + # Iterate from bottom to top until we find data + for zmin in range(0, len(slicer)): + if np.any(slicer[zmin] > threshold): + break + + # Conversely from top to bottom + for zmax in range(len(slicer) - 1, zmin, -1): + if np.any(slicer[zmax] > threshold): + break + + return zmin, zmax \ No newline at end of file From 283a623b8c53bfc32679419f1d4a4cc236216898 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 5 Feb 2024 02:39:30 -0500 Subject: [PATCH 02/15] add conversion script for dcm-zuri pretraining --- .../convert_bids_to_nnUNetv2_pretrain.py | 404 ++++++++++++++++++ 1 file changed, 404 insertions(+) create mode 100644 dataset-conversion/convert_bids_to_nnUNetv2_pretrain.py diff --git a/dataset-conversion/convert_bids_to_nnUNetv2_pretrain.py b/dataset-conversion/convert_bids_to_nnUNetv2_pretrain.py new file mode 100644 index 0000000..04574d1 --- /dev/null +++ b/dataset-conversion/convert_bids_to_nnUNetv2_pretrain.py @@ -0,0 +1,404 @@ +""" +Convert BIDS-structured datasets (dcm-zurich) to the nnUNetv2 dataset format. +Full details about the format can be found here: +https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md + +The script to be used on a single dataset or multiple datasets. + +The script in default creates region-based labels for segmenting both lesion and the spinal cord. + +Currently only supports the conversion of a single contrast. In case of multiple contrasts, the script should be +modified to include those as well. + +Note: the script performs RPI reorientation of the images and labels + +Usage example multiple datasets: + python convert_bids_to_nnUNetv2_praxis.py + --path-data ~/data/dcm-zurich-lesions ~/data/dcm-zurich-lesions-20231115 + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + --region-based + +Usage example single dataset: + python convert_bids_to_nnUNetv2_praxis.py + --path-data ~/data/dcm-zurich-lesions + --path-out ${nnUNet_raw} + -dname DCMlesions + -dnum 601 + --split 0.8 0.2 + --seed 50 + --region-based + +Authors: Naga Karthik, Jan Valosek +""" + +import argparse +from pathlib import Path +import json +import os +import re +import shutil +import yaml +from collections import OrderedDict +from loguru import logger +from sklearn.model_selection import train_test_split +from utils import binarize_label, create_region_based_label, get_git_branch_and_commit, Image +from tqdm import tqdm + +import nibabel as nib + + +def get_parser(): + # parse command line arguments + parser = argparse.ArgumentParser(description='Convert BIDS-structured dataset to nnUNetV2 database format.') + parser.add_argument('--path-data', nargs='+', required=True, type=str, + help='Path to BIDS dataset(s) (list).') + parser.add_argument('--path-out', help='Path to output directory.', required=True) + parser.add_argument('--dataset-name', '-dname', default='DCMcompression', type=str, + help='Specify the task name.') + parser.add_argument('--dataset-number', '-dnum', default=601, type=int, + help='Specify the task number, has to be greater than 500 but less than 999. 
e.g 502') + parser.add_argument('--seed', default=42, type=int, + help='Seed to be used for the random number generator split into training and test sets.') + parser.add_argument('--region-based', action='store_true', default=False, + help='If set, the script will create labels for region-based nnUNet training. Default: False') + # argument that accepts a list of floats as train val test splits + parser.add_argument('--split', nargs='+', type=float, default=[0.8, 0.2], + help='Ratios of training (includes validation) and test splits lying between 0-1. Example: ' + '--split 0.8 0.2') + return parser + + +def get_region_based_label(subject_label_file, subject_image_file, sub_ses_name, thr=0.5): + # define path for sc seg file + subject_seg_file = subject_label_file.replace('_label-lesion', '_label-SC_mask-manual') + + # check if the seg file exists + if not os.path.exists(subject_seg_file): + logger.info(f"Spinal cord segmentation file for subject {sub_ses_name} does not exist. Skipping.") + return None + + # create region-based label + seg_lesion_nii = create_region_based_label(subject_label_file, subject_seg_file, subject_image_file, + sub_ses_name, thr=thr) + + # save the region-based label + combined_seg_file = subject_label_file.replace('_label-lesion', '_SC-lesion') + nib.save(seg_lesion_nii, combined_seg_file) + + return combined_seg_file + + +def create_directories(path_out, site): + """Create test directories for a specified site. + + Args: + path_out (str): Base output directory. + site (str): Site identifier, such as 'dcm-zurich-lesions + """ + paths = [Path(path_out, f'imagesTs_{site}'), + Path(path_out, f'labelsTs_{site}')] + + for path in paths: + path.mkdir(parents=True, exist_ok=True) + + +def find_site_in_path(path): + """Extracts site identifier from the given path. + + Args: + path (str): Input path containing a site identifier. + + Returns: + str: Extracted site identifier or None if not found. + """ + # Find 'dcm-zurich' + match = re.search(r'dcm-zurich?', path) + return match.group(0) if match else None + + +def create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits): + # create a yaml file containing the list of training and test niftis + niftis_dict = { + f"train": sorted(train_niftis), + f"test": sorted(test_nifitis) + } + + # write the train and test niftis to a yaml file + with open(os.path.join(path_out, f"train_test_split_seed{args.seed}.yaml"), "w") as outfile: + yaml.dump(niftis_dict, outfile, default_flow_style=False) + + # c.f. dataset json generation + # In nnUNet V2, dataset.json file has become much shorter. The description of the fields and changes + # can be found here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/dataset_format.md#datasetjson + # this file can be automatically generated using the following code here: + # https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/dataset_conversion/generate_dataset_json.py + # example: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task055_SegTHOR.py + + json_dict = OrderedDict() + json_dict['name'] = args.dataset_name + json_dict['description'] = args.dataset_name + json_dict['reference'] = "TBD" + json_dict['licence'] = "TBD" + json_dict['release'] = "0.0" + json_dict['numTraining'] = train_ctr + json_dict['numTest'] = test_ctr + json_dict['seed_used'] = args.seed + json_dict['dataset_versions'] = dataset_commits + json_dict['image_orientation'] = "RPI" + + # The following keys are the most important ones. 
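+    # A sketch of the dataset.json written below for this pretraining dataset
+    # (assuming the default, i.e. --region-based not set):
+    #   "channel_names": {"0": "acq-sag_T2w"},
+    #   "labels": {"background": 0, "lesion": 1},
+    #   "file_ending": ".nii.gz"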
+ """ + channel_names: + Channel names must map the index to the name of the channel. For BIDS, this refers to the contrast suffix. + { + 0: 'T1', + 1: 'CT' + } + Note that the channel names may influence the normalization scheme!! Learn more in the documentation. + + labels: + This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. + Example regular labels: + { + 'background': 0, + 'left atrium': 1, + 'some other label': 2 + } + Example region-based training: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/region_based_training.md + { + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': 3 + } + Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! + """ + + json_dict['channel_names'] = { + 0: "acq-sag_T2w", + } + + if not args.region_based: + json_dict['labels'] = { + "background": 0, + "lesion": 1, + } + else: + json_dict['labels'] = { + "background": 0, + "sc": [1, 2], + # "sc": 1, + "lesion": 2, + } + json_dict['regions_class_order'] = [1, 2] + + # Needed for finding the files correctly. IMPORTANT! File endings must match between images and segmentations! + json_dict['file_ending'] = ".nii.gz" + + # create dataset_description.json + json_object = json.dumps(json_dict, indent=4) + # write to dataset description + # nn-unet requires it to be "dataset.json" + dataset_dict_name = f"dataset.json" + with open(os.path.join(path_out, dataset_dict_name), "w") as outfile: + outfile.write(json_object) + + +def main(): + parser = get_parser() + args = parser.parse_args() + + train_ratio, test_ratio = args.split + path_out = Path(os.path.join(os.path.abspath(args.path_out), f'Dataset{args.dataset_number}_{args.dataset_name}')) + + # create individual directories for train and test images and labels + path_out_imagesTr = Path(os.path.join(path_out, 'imagesTr')) + path_out_labelsTr = Path(os.path.join(path_out, 'labelsTr')) + # create the training directories + Path(path_out).mkdir(parents=True, exist_ok=True) + Path(path_out_imagesTr).mkdir(parents=True, exist_ok=True) + Path(path_out_labelsTr).mkdir(parents=True, exist_ok=True) + + # save output to a log file + logger.add(os.path.join(path_out, "logs.txt"), rotation="10 MB", level="INFO") + + # Check if dataset paths exist + for path in args.path_data: + if not os.path.exists(path): + raise ValueError(f"Path {path} does not exist.") + + # Get sites from the input paths + sites = set(find_site_in_path(path) for path in args.path_data if find_site_in_path(path)) + # Single site + if len(sites) == 1: + create_directories(path_out, sites.pop()) + # Multiple sites + else: + for site in sites: + create_directories(path_out, site) + + all_files, train_images, test_images = [], {}, {} + # temp dict for storing dataset commits + dataset_commits = {} + + # loop over the datasets + for dataset in args.path_data: + root = Path(dataset) + + # get the git branch and commit ID of the dataset + dataset_name = os.path.basename(os.path.normpath(dataset)) + branch, commit = get_git_branch_and_commit(dataset) + dataset_commits[dataset_name] = f"git-{branch}-{commit}" + + # get recursively all GT compression labels + label_files = [str(path) for path in root.rglob('*_label-compression-manual.nii.gz')] + + # add to the list of all subjects + all_files.extend(label_files) + + # Get the training and test splits + tr_subs, te_subs = train_test_split(label_files, test_size=test_ratio, 
random_state=args.seed) + + # update the train and test images dicts with the key as the subject and value as the path to the subject + train_images.update({sub: os.path.join(root, sub) for sub in tr_subs}) + test_images.update({sub: os.path.join(root, sub) for sub in te_subs}) + + logger.info(f"Found subjects in the training set (combining all datasets): {len(train_images)}") + logger.info(f"Found subjects in the test set (combining all datasets): {len(test_images)}") + # Print test images for each site + for site in sites: + logger.info(f"Test subjects in {site}: {len([sub for sub in test_images if site in sub])}") + + # print version of each dataset in a separate line + for dataset_name, dataset_commit in dataset_commits.items(): + logger.info(f"{dataset_name} dataset version: {dataset_commit}") + + # Counters for train and test sets + train_ctr, test_ctr = 0, 0 + train_niftis, test_nifitis = [], [] + # Loop over all images + for subject_label_file in tqdm(all_files, desc="Iterating over all images"): + + # Construct path to the background image + subject_image_file = subject_label_file.replace('/derivatives/labels', '').replace('_label-compression-manual', '') + + # Train images + if subject_label_file in train_images.keys(): + + train_ctr += 1 + # add the subject image file to the list of training niftis + train_niftis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + subject_image_file_nnunet = os.path.join(path_out_imagesTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}_0000.nii.gz") + subject_label_file_nnunet = os.path.join(path_out_labelsTr, + f"{args.dataset_name}_{sub_name}_{train_ctr:03d}.nii.gz") + + # use region-based labels if required + if args.region_based: + # overwritten the subject_label_file with the region-based label + subject_label_file = get_region_based_label(subject_label_file, + subject_image_file, sub_name, thr=0.5) + if subject_label_file is None: + print(f"Skipping since the region-based label could not be generated") + continue + + # copy the files to new structure + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # binarize the label file only if region-based training is not set (since the region-based labels are + # already binarized) + if not args.region_based: + binarize_label(subject_image_file_nnunet, subject_label_file_nnunet) + + # Test images + elif subject_label_file in test_images: + + test_ctr += 1 + # add the image file to the list of testing niftis + test_nifitis.append(os.path.basename(subject_image_file)) + + # create the new convention names for nnunet + sub_name = f"{str(Path(subject_image_file).name).replace('.nii.gz', '')}" + + subject_image_file_nnunet = os.path.join(Path(path_out, + f'imagesTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}_0000.nii.gz') + subject_label_file_nnunet = os.path.join(Path(path_out, + f'labelsTs_{find_site_in_path(test_images[subject_label_file])}'), + f'{args.dataset_name}_{sub_name}_{test_ctr:03d}.nii.gz') + + # use region-based labels if 
required + if args.region_based and find_site_in_path(test_images[subject_label_file]) != 'site_014': + # overwritten the subject_label_file with the region-based label + subject_label_file = get_region_based_label(subject_label_file, + subject_image_file, sub_name, thr=0.5) + if subject_label_file is None: + continue + + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + print("here") + print(f"\nCopying {subject_label_file} to {subject_label_file_nnunet}") + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + # copy the files to new structure + shutil.copyfile(subject_image_file, subject_image_file_nnunet) + # print(f"\nCopying {subject_image_file} to {subject_image_file_nnunet}") + # convert the image and label to RPI using the Image class + image = Image(subject_image_file_nnunet) + image.change_orientation("RPI") + image.save(subject_image_file_nnunet) + + # binarize the label file only if region-based training is not set (since the region-based labels are + # already binarized) + if not args.region_based: + shutil.copyfile(subject_label_file, subject_label_file_nnunet) + label = Image(subject_label_file_nnunet) + label.change_orientation("RPI") + label.save(subject_label_file_nnunet) + + binarize_label(subject_image_file_nnunet, subject_label_file_nnunet) + + else: + print("Skipping file, could not be located in the Train or Test splits split.", subject_label_file) + + logger.info(f"----- Dataset conversion finished! -----") + logger.info(f"Number of training and validation images (across all sites): {train_ctr}") + logger.info(f"Number of test images (across all sites): {test_ctr}") + # Get number of test images per site + test_images_per_site = {} + for test_subject in test_images: + site = find_site_in_path(test_subject) + if site in test_images_per_site: + test_images_per_site[site] += 1 + else: + test_images_per_site[site] = 1 + # Print number of test images per site + for site, num_images in test_images_per_site.items(): + logger.info(f"Number of test images in {site}: {num_images}") + + # create the yaml file containing the train and test niftis + create_yaml(train_niftis, test_nifitis, path_out, args, train_ctr, test_ctr, dataset_commits) + + +if __name__ == "__main__": + main() \ No newline at end of file From 464f502f4a884c710d366abedecde178f7600f4a Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Mon, 5 Feb 2024 02:40:19 -0500 Subject: [PATCH 03/15] add pre-training & finetuning script --- ...n_dcm_zurich_pretraining_and_finetuning.sh | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 nnunet/run_dcm_zurich_pretraining_and_finetuning.sh diff --git a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh new file mode 100644 index 0000000..ef75076 --- /dev/null +++ b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# +# This script combines pre-training on dcm-zurich for compression detection and +# fine-tuning on dcm-zurich-lesions for lesion segmentation. +# +# Assumes that the datasets for pretraining and finetuning already exist. +# These are most likely the outputs of `convert_bids_to_nnUNetv2*.py` scripts. 
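+# The script also assumes that the standard nnUNetv2 environment variables
+# (nnUNet_raw, nnUNet_preprocessed, nnUNet_results) are exported, since
+# ${nnUNet_raw} and ${nnUNet_results} are referenced below.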
+# +# Usage: +# cd ~/code/model-seg-dcm +# ./nnunet/run_dcm_zurich_pretraining_and_finetuning.sh +# +# Author: Jan Valosek, Naga Karthik +# + +# Uncomment for full verbose +# set -x + +# Immediately exit if error +set -e -o pipefail + +# Exit if user presses CTRL+C (Linux) or CMD+C (OSX) +trap "echo Caught Keyboard Interrupt within script. Exiting now.; exit" INT + +# Global variables +cuda_visible_devices=2 +folds=(3) +sites=(dcm-zurich-lesions dcm-zurich-lesions-20231115) +nnunet_trainer="nnUNetTrainer" +# nnunet_trainer="nnUNetTrainer_2000epochs" # default: nnUNetTrainer +configuration="3d_fullres" # for 2D training, use "2d" + +# Variables for pretraining on dcm-zurich (i.e. source dataset) +dataset_num_ptr="191" +dataset_name_ptr="Dataset${dataset_num_ptr}_dcmZurichPretrain" +dataset_git_annex_name="dcm-zurich" + +# Variables for finetuning on dcm-zurich-lesions (i.e. target dataset) +dataset_num_ftu="192" +dataset_name_ftu="Dataset${dataset_num_ftu}_dcmZurichLesionsFinetune" + + +echo "-------------------------------------------------------" +echo "Running plan_and_preprocess for ${dataset_name_ftu}" +echo "-------------------------------------------------------" +nnUNetv2_plan_and_preprocess -d ${dataset_num_ftu} --verify_dataset_integrity -c ${configuration} + +echo "-------------------------------------------------------" +echo "Running plan_and_preprocess for ${dataset_name_ptr}" +echo "-------------------------------------------------------" +nnUNetv2_plan_and_preprocess -d ${dataset_num_ptr} --verify_dataset_integrity -c ${configuration} + +echo "-------------------------------------------------------" +echo "Extracting dataset fingerprint for ${dataset_name_ptr}" +echo "-------------------------------------------------------" +nnUNetv2_extract_fingerprint -d ${dataset_num_ptr} + +echo "-------------------------------------------------------" +echo "Moving plans from ${dataset_name_ftu} to ${dataset_name_ptr}" +echo "-------------------------------------------------------" +nnUNetv2_move_plans_between_datasets -s ${dataset_num_ftu} -t ${dataset_num_ptr} -sp nnUNetPlans -tp nnUNetMovedPlans + +echo "-------------------------------------------------------" +echo "Running (only) preprocessing for ${dataset_name_ptr} after moving plans" +echo "-------------------------------------------------------" +nnUNetv2_preprocess -d ${dataset_num_ptr} -plans_name nnUNetMovedPlans + + +echo "-------------------------------------------------------" +echo "Running pretraining on ${dataset_name_ptr} ..." +echo "-------------------------------------------------------" +# training +CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_num_ptr} ${configuration} all -tr ${nnunet_trainer} + +echo "-------------------------------------------------------" +echo "Running inference on ${dataset_name_ptr} ..." +echo "-------------------------------------------------------" +# running inference on the source dataset +CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ptr}/imagesTs_${dataset_git_annex_name} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_all/test -d ${dataset_num_ptr} -f all -c ${configuration} + + +echo "-------------------------------------------------------" +echo "Pretraining done, Running finetuning on ${dataset_name_ftu} ..." 
+echo "-------------------------------------------------------" +path_ptr_weights=${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_all/checkpoint_best.pth + + +for fold in ${folds[@]}; do + echo "-------------------------------------------" + echo "Training/Finetuning on Fold $fold" + echo "-------------------------------------------" + + # training + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_name_ftu} ${configuration} ${fold} -tr ${nnunet_trainer} -pretrained_weights ${path_ptr_weights} + + echo "" + echo "-------------------------------------------" + echo "Training completed, Testing on Fold $fold" + echo "-------------------------------------------" + + # run inference on testing sets for each site + for site in ${sites[@]}; do + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} + + echo "-------------------------------------------------------" + echo "Running ANIMA evaluation on Test set for ${site} " + echo "-------------------------------------------------------" + + python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site} + + done + +done \ No newline at end of file From a7ebe177e9fe0f8c3b7fe5cc6455c1aa4318a547 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 1 Mar 2024 17:15:42 -0500 Subject: [PATCH 04/15] add script for creating MSD datalists for SCI and DCM datasets --- monai/create_msd_data.py | 189 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 monai/create_msd_data.py diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py new file mode 100644 index 0000000..3c5acc2 --- /dev/null +++ b/monai/create_msd_data.py @@ -0,0 +1,189 @@ +import os +import re +import json +from tqdm import tqdm +import yaml +import argparse +from pathlib import Path +from loguru import logger +from sklearn.model_selection import train_test_split +from utils import get_git_branch_and_commit + + +def get_parser(): + parser = argparse.ArgumentParser(description='Code for MSD-style JSON datalist for DCM and SCI lesions dataset.') + + parser.add_argument('--path-data', nargs='+', required=True, type=str, help='Path to BIDS dataset(s) (list).') + parser.add_argument('--path-out', type=str, required=True, + help='Path to the output directory where dataset json is saved') + parser.add_argument('--split', nargs='+', type=float, default=[0.7, 0.2, 0.1], + help='Ratios of training, validation and test splits lying between 0-1. ' + 'Example: --split 0.7 0.2 0.1') + parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") + parser.add_argument('--pathology', default='dcm', type=str, required=True, + help="Type of pathology in the dataset(s). Default: 'dcm' (for dcm-zurich-lesions). " + "Options: 'sci' (for sci lesions) ") + + return parser + + +def find_site_in_path(path, pathology='dcm'): + """Extracts site identifier from the given path. + + Args: + path (str): Input path containing a site identifier. + + Returns: + str: Extracted site identifier or None if not found. 
+    """
+    if pathology == 'dcm':
+        # Find 'dcm-zurich-lesions' or 'dcm-zurich-lesions-20231115'
+        match = re.search(r'dcm-zurich-lesions(-\d{8})?', path)
+        return match.group(0) if match else None
+    elif pathology == 'sci':
+        # Find 'sci-zurich', 'sci-colorado', or 'sci-paris'
+        match = re.search(r'sci-(zurich|colorado|paris)', path)
+        return match.group(0) if match else None
+
+
+def main():
+    args = get_parser().parse_args()
+
+    train_ratio, val_ratio, test_ratio = args.split
+    seed = args.seed
+    root = args.path_data
+    pathology = args.pathology
+    if pathology == 'dcm':
+        lesion_fname_suffix = 'label-lesion'
+        sc_fname_suffix = 'label-SC_mask-manual'
+        datalist_fname = f"dataset_dcm_lesions_seed{seed}"
+    elif pathology == 'sci':
+        lesion_fname_suffix = 'T2w_lesion-manual'
+        sc_fname_suffix = 'T2w_seg-manual'
+        datalist_fname = f"dataset_sci_lesions_seed{seed}"
+
+    # Check if dataset paths exist
+    for path in args.path_data:
+        if not os.path.exists(path):
+            raise ValueError(f"Path {path} does not exist.")
+
+    # Get sites from the input paths
+    sites = set(find_site_in_path(path, pathology) for path in args.path_data if find_site_in_path(path, pathology))
+
+    all_subjects, train_images, val_images, test_images = [], {}, {}, {}
+    # temp dict for storing dataset commits
+    dataset_commits = {}
+
+    # loop over the datasets
+    for idx, dataset in enumerate(args.path_data, start=1):
+        root = Path(dataset)
+
+        # get the git branch and commit ID of the dataset
+        dataset_name = os.path.basename(os.path.normpath(dataset))
+        branch, commit = get_git_branch_and_commit(dataset)
+        dataset_commits[dataset_name] = f"git-{branch}-{commit}"
+
+        # get recursively all the subjects from the root folder
+        subjects = [sub for sub in os.listdir(root) if sub.startswith("sub-")]
+
+        # add to the list of all subjects
+        all_subjects.extend(subjects)
+
+        # Get the training and test splits
+        tr_subs, te_subs = train_test_split(subjects, test_size=test_ratio, random_state=args.seed)
+        if "sci-paris" in dataset:
+            # add all test subjects to the to the training set
+            tr_subs.extend(te_subs)
+            te_subs = []
+        tr_subs, val_subs = train_test_split(tr_subs, test_size=val_ratio / (train_ratio + val_ratio), random_state=args.seed)
+
+        # recursively find the lesion files for the training, validation, and test subjects
+        tr_lesion_files = [str(path) for sub in tr_subs for path in Path(root).rglob(f"{sub}_*{lesion_fname_suffix}.nii.gz")]
+        val_lesion_files = [str(path) for sub in val_subs for path in Path(root).rglob(f"{sub}_*{lesion_fname_suffix}.nii.gz")]
+        te_lesion_files = [str(path) for sub in te_subs for path in Path(root).rglob(f"{sub}_*{lesion_fname_suffix}.nii.gz")]
+
+        # update the train and test images dicts with the lesion file as the key and its path as the value
+        train_images.update({sub: os.path.join(root, sub) for sub in tr_lesion_files})
+        val_images.update({sub: os.path.join(root, sub) for sub in val_lesion_files})
+        test_images.update({
+            f"site_{idx}": {sub: os.path.join(root, sub) for sub in te_lesion_files}
+        })
+        # test_images.update({sub: os.path.join(root, sub) for sub in te_subs})
+
+    # remove empty test sites
+    test_images = {k: v for k, v in test_images.items() if v}
+
+    logger.info(f"Found subjects in the training set (combining all datasets): {len(train_images)}")
+    logger.info(f"Found subjects in the validation set (combining all datasets): {len(val_images)}")
+    logger.info(f"Found subjects in the test set (combining all datasets): {len([sub for site in test_images.values() for sub in site])}")
+
+    # # dump
train/val/test splits into a yaml file + # with open(f"data_split_{contrast}_{args.label_type}_seed{seed}.yaml", 'w') as file: + # yaml.dump({'train': train_subjects, 'val': val_subjects, 'test': test_subjects}, file, indent=2, sort_keys=True) + + # keys to be defined in the dataset_0.json + params = {} + params["description"] = "spine-generic-uncropped" + params["labels"] = { + "0": "background", + "1": "soft-sc-seg" + } + params["license"] = "nk" + params["modality"] = { + "0": "MRI" + } + params["name"] = "spine-generic" + params["numTest"] = len([sub for site in test_images.values() for sub in site]) + params["numTraining"] = len(train_images) + params["numValidation"] = len(val_images) + params["seed"] = args.seed + params["reference"] = "University of Zurich" + params["tensorImageSize"] = "3D" + + train_images_dict = {"train": train_images} + val_images_dict = {"validation": val_images} + test_images_dict = {} + for site, images in test_images.items(): + temp_dict = {f"test_{site}": images} + test_images_dict.update(temp_dict) + + all_images_list = [train_images_dict, val_images_dict, test_images_dict] + + for images_dict in tqdm(all_images_list, desc="Iterating through train/val/test splits"): + + for name, images_list in images_dict.items(): + + temp_list = [] + for subject_no, image in enumerate(images_list): + + temp_data_t2w = {} + if pathology == 'dcm': + temp_data_t2w["image"] = image.replace('/derivatives/labels', '').replace(f'_{lesion_fname_suffix}', '') + temp_data_t2w["label-sc"] = image.replace(f'_{lesion_fname_suffix}', f'_{sc_fname_suffix}') + elif pathology == 'sci': + temp_data_t2w["image"] = image.replace('/derivatives/labels', '').replace(f'_{lesion_fname_suffix}', '_T2w') + temp_data_t2w["label-sc"] = image.replace(f'_{lesion_fname_suffix}', f'_{sc_fname_suffix}') + + temp_data_t2w["label-lesion"] = image + + if os.path.exists(temp_data_t2w["label-lesion"]) and os.path.exists(temp_data_t2w["label-sc"]) and os.path.exists(temp_data_t2w["image"]): + temp_list.append(temp_data_t2w) + else: + logger.info(f"Either Image/SC-Seg/Lesion-seg does not exist.") + + params[name] = temp_list + logger.info(f"Number of images in {name} set: {len(temp_list)}") + + final_json = json.dumps(params, indent=4, sort_keys=True) + if not os.path.exists(args.path_out): + os.makedirs(args.path_out, exist_ok=True) + + jsonFile = open(args.path_out + "/" + f"{datalist_fname}.json", "w") + jsonFile.write(final_json) + jsonFile.close() + + +if __name__ == "__main__": + main() + + From ddccf09374db0b17d4efe1035469a4c453637ec2 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 1 Mar 2024 17:16:18 -0500 Subject: [PATCH 05/15] add utils for creating MSD datalists --- monai/utils.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 monai/utils.py diff --git a/monai/utils.py b/monai/utils.py new file mode 100644 index 0000000..798d711 --- /dev/null +++ b/monai/utils.py @@ -0,0 +1,46 @@ +import subprocess + + +def get_git_branch_and_commit(dataset_path=None): + """ + :return: git branch and commit ID, with trailing '*' if modified + Taken from: https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/utils/sys.py#L476 + and https://github.com/spinalcordtoolbox/spinalcordtoolbox/blob/master/spinalcordtoolbox/utils/sys.py#L461 + """ + + # branch info + b = subprocess.Popen(["git", "rev-parse", "--abbrev-ref", "HEAD"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, cwd=dataset_path) + b_output, _ = b.communicate() + 
b_status = b.returncode + + if b_status == 0: + branch = b_output.decode().strip() + else: + branch = "!?!" + + # commit info + p = subprocess.Popen(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=dataset_path) + output, _ = p.communicate() + status = p.returncode + if status == 0: + commit = output.decode().strip() + else: + commit = "?!?" + + p = subprocess.Popen(["git", "status", "--porcelain"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=dataset_path) + output, _ = p.communicate() + status = p.returncode + if status == 0: + unclean = True + for line in output.decode().strip().splitlines(): + line = line.rstrip() + if line.startswith("??"): # ignore ignored files, they can't hurt + continue + break + else: + unclean = False + if unclean: + commit += "*" + + return branch, commit From 60c4e00453780e0a4f18670a5cfb3b9d72692860 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Fri, 1 Mar 2024 17:17:47 -0500 Subject: [PATCH 06/15] add init version of models --- monai/models.py | 168 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 monai/models.py diff --git a/monai/models.py b/monai/models.py new file mode 100644 index 0000000..7c6d88d --- /dev/null +++ b/monai/models.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger + +# ---------------------------- Imports for nnUNet's Model ----------------------------- +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 + + +# ====================================================================================================== +# Define plans json taken from nnUNet +# ====================================================================================================== +nnunet_plans = { + "UNet_class_name": "PlainConvUNet", + "UNet_base_num_features": 32, + "n_conv_per_stage_encoder": [2, 2, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [2, 2, 2, 2, 2], + "pool_op_kernel_sizes": [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ], + "conv_kernel_sizes": [ + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3], + [3, 3, 3] + ], + "unet_max_num_features": 320, +} + + +# ====================================================================================================== +# Utils for nnUNet's Model +# ==================================================================================================== +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +# ====================================================================================================== +# Define the network based on plans json +# ==================================================================================================== +def create_nnunet_from_plans( + plans, + num_input_channels: int, + num_classes: int, + deep_supervision: bool = True, + initialization: str = "scratch", + pretrained_checkpoint: str = None + ): + 
""" + Adapted from nnUNet's source code: + https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/utilities/get_network_from_plans.py#L9 + + """ + num_stages = len(plans["conv_kernel_sizes"]) + + dim = len(plans["conv_kernel_sizes"][0]) + conv_op = convert_dim_to_conv_op(dim) + + segmentation_network_class_name = plans["UNet_class_name"] + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': plans["n_conv_per_stage_encoder"], + 'n_conv_per_stage_decoder': plans["n_conv_per_stage_decoder"] + } + + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(plans["UNet_base_num_features"] * 2 ** i, + plans["unet_max_num_features"]) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=plans["conv_kernel_sizes"], + strides=plans["pool_op_kernel_sizes"], + num_classes=num_classes, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + + if initialization == "scratch": + logger.info("Initializing weights from scratch ...") + model.apply(InitWeights_He(1e-2)) + + elif initialization == "pretrained": + assert pretrained_checkpoint is not None, "Please provide the path to the pretrained checkpoint." 
+        logger.info(f"Loading weights from {pretrained_checkpoint} ...")
+        pretrained_weights = torch.load(pretrained_checkpoint)
+
+        # remove the segmentation layers (the 1x1(x1) layers that produce the segmentation maps,
+        # identified by keys with '.seg_layers')
+        pretrained_weights = {k: v for k, v in pretrained_weights.items() if '.seg_layers' not in k}
+        # strict=False: the filtered-out '.seg_layers' keys stay randomly initialized
+        model.load_state_dict(pretrained_weights, strict=False)
+
+    if network_class == ResidualEncoderUNet:
+        model.apply(init_last_bn_before_add_to_0)
+
+    return model
+
+
+
+if __name__ == "__main__":
+
+    enable_deep_supervision = True
+    model = create_nnunet_from_plans(nnunet_plans, 1, 1, enable_deep_supervision, "scratch")
+    input = torch.randn(1, 1, 64, 192, 320)
+    # M1: using encoder and decoder separately
+    skips = model.encoder(input)
+    print(skips[-1].shape)
+    output = model.decoder(skips)
+    # # M2: using the whole model directly
+    # output = model(input)
+    # if enable_deep_supervision:
+    #     for i in range(len(output)):
+    #         print(output[i].shape)
+    # else:
+    #     print(output.shape)
+
+    # # save the model
+    # torch.save(model.state_dict(), "model.pth")
+
+    # # load the model
+    # model = create_nnunet_from_plans(nnunet_plans, 1, 1, enable_deep_supervision, "pretrained", "model.pth")
+
+    # print(output.shape)
\ No newline at end of file

From 9093e3fa89c27980e562bfcb2024cb918391523a Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Tue, 5 Mar 2024 12:21:29 -0500
Subject: [PATCH 07/15] add unified script for sci pretraining and dcm
 finetuning; lesions

---
 .../run_sci_pretraining_and_dcm_finetuning.sh | 113 ++++++++++++++++++
 1 file changed, 113 insertions(+)
 create mode 100644 nnunet/run_sci_pretraining_and_dcm_finetuning.sh

diff --git a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh
new file mode 100644
index 0000000..f1db5cd
--- /dev/null
+++ b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh
@@ -0,0 +1,113 @@
+#!/bin/bash
+#
+# This script combines pre-training on SCI data (3 sites) for lesion segmentation and
+# fine-tuning on dcm-zurich-lesions, also for lesion segmentation.
+#
+# Assumes that the datasets for pretraining and finetuning already exist.
+# These are most likely the outputs of `convert_bids_to_nnUNetv2*.py` scripts.
+#
+# Usage:
+#     cd ~/code/model-seg-dcm
+#     ./nnunet/run_sci_pretraining_and_dcm_finetuning.sh
+#
+# Author: Jan Valosek, Naga Karthik
+#
+
+# Uncomment for full verbose
+# set -x
+
+# Immediately exit if error
+set -e -o pipefail
+
+# Exit if user presses CTRL+C (Linux) or CMD+C (OSX)
+trap "echo Caught Keyboard Interrupt within script. Exiting now.; exit" INT
+
+# Global variables
+cuda_visible_devices=1
+folds=(1)
+sites=(dcm-zurich-lesions dcm-zurich-lesions-20231115)
+nnunet_trainer="nnUNetTrainerDiceCELoss_noSmooth"
+# nnunet_trainer="nnUNetTrainer_1epoch" # default: nnUNetTrainer; nnUNetTrainer_Xepochs
+configuration="3d_fullres" # for 2D training, use "2d"
+
+# Variables for pretraining on SCI data 3 sites (i.e. source dataset)
+dataset_num_ptr="190"
+dataset_name_ptr="Dataset${dataset_num_ptr}_tSCI3SitesALSeed710Pretrain"
+# dataset_git_annex_name="dcm-zurich"
+
+# Variables for finetuning on dcm-zurich-lesions (i.e. target dataset)
+dataset_num_ftu="192"
+dataset_name_ftu="Dataset${dataset_num_ftu}_dcmZurichLesionsFinetune"
+
+
+echo "-------------------------------------------------------"
+echo "Running plan_and_preprocess for ${dataset_name_ftu} ...
(Finetuning dataset)" +echo "-------------------------------------------------------" +nnUNetv2_plan_and_preprocess -d ${dataset_num_ftu} --verify_dataset_integrity -c ${configuration} + +echo "-------------------------------------------------------" +echo "Running plan_and_preprocess for ${dataset_name_ptr} ... (Pretraining dataset)" +echo "-------------------------------------------------------" +nnUNetv2_plan_and_preprocess -d ${dataset_num_ptr} --verify_dataset_integrity -c ${configuration} + +echo "-------------------------------------------------------" +echo "Extracting dataset fingerprint for ${dataset_name_ptr} ... (Pretraining dataset)" +echo "-------------------------------------------------------" +nnUNetv2_extract_fingerprint -d ${dataset_num_ptr} + +echo "-------------------------------------------------------" +echo "Moving plans from ${dataset_name_ftu} to ${dataset_name_ptr} ... (Finetuning --> Pretraining)" +echo "-------------------------------------------------------" +nnUNetv2_move_plans_between_datasets -s ${dataset_num_ftu} -t ${dataset_num_ptr} -sp nnUNetPlans -tp nnUNetMovedPlans + +echo "-------------------------------------------------------" +echo "Running (only) preprocessing for ${dataset_name_ptr} after moving plans ... (Pretraining dataset)" +echo "-------------------------------------------------------" +nnUNetv2_preprocess -d ${dataset_num_ptr} -plans_name nnUNetMovedPlans + + +echo "-------------------------------------------------------" +echo "Running pretraining on ${dataset_name_ptr} ..." +echo "-------------------------------------------------------" +# training +CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_num_ptr} ${configuration} all -tr ${nnunet_trainer} -p nnUNetMovedPlans + +echo "-------------------------------------------------------" +echo "Running inference on ${dataset_name_ptr} ..." +echo "-------------------------------------------------------" +# # running inference on the source dataset +# CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ptr}/imagesTs_${dataset_git_annex_name} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_all/test -d ${dataset_num_ptr} -f all -c ${configuration} + + +echo "-------------------------------------------------------" +echo "Pretraining done, Running finetuning on ${dataset_name_ftu} ..." 
+echo "-------------------------------------------------------" +path_ptr_weights=${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_all/checkpoint_best.pth + + +for fold in ${folds[@]}; do + echo "-------------------------------------------" + echo "Training/Finetuning on Fold $fold" + echo "-------------------------------------------" + + # training + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_name_ftu} ${configuration} ${fold} -tr ${nnunet_trainer} -pretrained_weights ${path_ptr_weights} + + echo "" + echo "-------------------------------------------" + echo "Training completed, Testing on Fold $fold" + echo "-------------------------------------------" + + # run inference on testing sets for each site + for site in ${sites[@]}; do + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} + + echo "-------------------------------------------------------" + echo "Running ANIMA evaluation on Test set for ${site} " + echo "-------------------------------------------------------" + + python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site} + + done + +done \ No newline at end of file From d526d4f1cbefdedabd97063835c71e617aaa7237 Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 5 Mar 2024 12:34:08 -0500 Subject: [PATCH 08/15] fix bug in pretraining cmd to use moved plans --- nnunet/run_dcm_zurich_pretraining_and_finetuning.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh index ef75076..e580f63 100644 --- a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh +++ b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh @@ -70,7 +70,7 @@ echo "-------------------------------------------------------" echo "Running pretraining on ${dataset_name_ptr} ..." echo "-------------------------------------------------------" # training -CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_num_ptr} ${configuration} all -tr ${nnunet_trainer} +CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_num_ptr} ${configuration} all -tr ${nnunet_trainer} -p nnUNetMovedPlans echo "-------------------------------------------------------" echo "Running inference on ${dataset_name_ptr} ..." 
From 14f85057a88b3dd9995774d5760f6900e320efdd Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 5 Mar 2024 17:29:15 -0500 Subject: [PATCH 09/15] fix bug in path for finding pretrained weights --- nnunet/run_dcm_zurich_pretraining_and_finetuning.sh | 2 +- nnunet/run_sci_pretraining_and_dcm_finetuning.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh index e580f63..0dfd5be 100644 --- a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh +++ b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh @@ -82,7 +82,7 @@ CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/$ echo "-------------------------------------------------------" echo "Pretraining done, Running finetuning on ${dataset_name_ftu} ..." echo "-------------------------------------------------------" -path_ptr_weights=${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_all/checkpoint_best.pth +path_ptr_weights=${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetMovedPlans__${configuration}/fold_all/checkpoint_best.pth for fold in ${folds[@]}; do diff --git a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh index f1db5cd..262d3dd 100644 --- a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh +++ b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh @@ -82,7 +82,7 @@ echo "-------------------------------------------------------" echo "-------------------------------------------------------" echo "Pretraining done, Running finetuning on ${dataset_name_ftu} ..." echo "-------------------------------------------------------" -path_ptr_weights=${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_all/checkpoint_best.pth +path_ptr_weights=${nnUNet_results}/${dataset_name_ptr}/${nnunet_trainer}__nnUNetMovedPlans__${configuration}/fold_all/checkpoint_best.pth for fold in ${folds[@]}; do From 4c738de091c2b246e16ad83018b597e2c014077a Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 5 Mar 2024 18:06:33 -0500 Subject: [PATCH 10/15] add new trainer variant for finetuning --- nnunet/run_dcm_zurich_pretraining_and_finetuning.sh | 10 +++++++--- nnunet/run_sci_pretraining_and_dcm_finetuning.sh | 12 ++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh index 0dfd5be..c0f6327 100644 --- a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh +++ b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh @@ -30,6 +30,10 @@ nnunet_trainer="nnUNetTrainer" # nnunet_trainer="nnUNetTrainer_2000epochs" # default: nnUNetTrainer configuration="3d_fullres" # for 2D training, use "2d" +# NOTE: after pre-training for 1000 epochs, fine-tuning doesn't need that many epochs +# hence, creating a new variant with less epochs +nnunet_trainer_ftu="nnUNetTrainer_250epochs" + # Variables for pretraining on dcm-zurich (i.e. 
source dataset) dataset_num_ptr="191" dataset_name_ptr="Dataset${dataset_num_ptr}_dcmZurichPretrain" @@ -91,7 +95,7 @@ for fold in ${folds[@]}; do echo "-------------------------------------------" # training - CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_name_ftu} ${configuration} ${fold} -tr ${nnunet_trainer} -pretrained_weights ${path_ptr_weights} + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_name_ftu} ${configuration} ${fold} -tr ${nnunet_trainer_ftu} -pretrained_weights ${path_ptr_weights} echo "" echo "-------------------------------------------" @@ -100,13 +104,13 @@ for fold in ${folds[@]}; do # run inference on testing sets for each site for site in ${sites[@]}; do - CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer_ftu} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} echo "-------------------------------------------------------" echo "Running ANIMA evaluation on Test set for ${site} " echo "-------------------------------------------------------" - python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site} + python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site} done diff --git a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh index 262d3dd..ca3e9d4 100644 --- a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh +++ b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh @@ -27,9 +27,13 @@ cuda_visible_devices=1 folds=(1) sites=(dcm-zurich-lesions dcm-zurich-lesions-20231115) nnunet_trainer="nnUNetTrainerDiceCELoss_noSmooth" -# nnunet_trainer="nnUNetTrainer_1epoch" # default: nnUNetTrainer; nnUNetTrainer_Xepochs +# nnunet_trainer="nnUNetTrainer" # default: nnUNetTrainer; nnUNetTrainer_Xepochs configuration="3d_fullres" # for 2D training, use "2d" +# NOTE: after pre-training for 1000 epochs, fine-tuning doesn't need that many epochs +# hence, creating a new variant with less epochs +nnunet_trainer_ftu="nnUNetTrainerDiceCELoss_noSmooth_500epochs" + # Variables for pretraining on SCI data 3 sites (i.e. 
source dataset) dataset_num_ptr="190" dataset_name_ptr="Dataset${dataset_num_ptr}_tSCI3SitesALSeed710Pretrain" @@ -91,7 +95,7 @@ for fold in ${folds[@]}; do echo "-------------------------------------------" # training - CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_name_ftu} ${configuration} ${fold} -tr ${nnunet_trainer} -pretrained_weights ${path_ptr_weights} + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_train ${dataset_name_ftu} ${configuration} ${fold} -tr ${nnunet_trainer_ftu} -pretrained_weights ${path_ptr_weights} echo "" echo "-------------------------------------------" @@ -100,13 +104,13 @@ for fold in ${folds[@]}; do # run inference on testing sets for each site for site in ${sites[@]}; do - CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer_ftu} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} echo "-------------------------------------------------------" echo "Running ANIMA evaluation on Test set for ${site} " echo "-------------------------------------------------------" - python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site} + python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site} done From 7e02d6d3814738b7d53a53656e7f45811faeb19d Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Wed, 6 Mar 2024 11:51:11 -0500 Subject: [PATCH 11/15] fix inference commands after finetuning --- nnunet/run_dcm_zurich_pretraining_and_finetuning.sh | 4 ++-- nnunet/run_sci_pretraining_and_dcm_finetuning.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh index c0f6327..fdaf217 100644 --- a/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh +++ b/nnunet/run_dcm_zurich_pretraining_and_finetuning.sh @@ -104,13 +104,13 @@ for fold in ${folds[@]}; do # run inference on testing sets for each site for site in ${sites[@]}; do - CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer_ftu} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} + CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer_ftu} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration} echo "-------------------------------------------------------" echo "Running ANIMA 
evaluation on Test set for ${site} "
         echo "-------------------------------------------------------"
 
-        python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site}
+        python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site}
 
     done
diff --git a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh
index ca3e9d4..79be419 100644
--- a/nnunet/run_sci_pretraining_and_dcm_finetuning.sh
+++ b/nnunet/run_sci_pretraining_and_dcm_finetuning.sh
@@ -104,13 +104,13 @@ for fold in ${folds[@]}; do
 
     # run inference on testing sets for each site
     for site in ${sites[@]}; do
-        CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer_ftu} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration}
+        CUDA_VISIBLE_DEVICES=${cuda_visible_devices} nnUNetv2_predict -i ${nnUNet_raw}/${dataset_name_ftu}/imagesTs_${site} -tr ${nnunet_trainer_ftu} -o ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} -d ${dataset_num_ftu} -f ${fold} -c ${configuration}
 
         echo "-------------------------------------------------------"
         echo "Running ANIMA evaluation on Test set for ${site} "
         echo "-------------------------------------------------------"
 
-        python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetMovedPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site}
+        python testing/compute_anima_metrics.py --pred-folder ${nnUNet_results}/${dataset_name_ftu}/${nnunet_trainer_ftu}__nnUNetPlans__${configuration}/fold_${fold}/test_${site} --gt-folder ${nnUNet_raw}/${dataset_name_ftu}/labelsTs_${site} --dataset-name ${site}
 
     done
 
 done

From 666a5c02659617acfd93ac4d31e03edd0a6c41ec Mon Sep 17 00:00:00 2001
From: Naga Karthik
Date: Thu, 7 Mar 2024 17:26:31 -0500
Subject: [PATCH 12/15] add init version of transforms from contrast-agnostic

---
 monai/transforms.py | 58 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 monai/transforms.py

diff --git a/monai/transforms.py b/monai/transforms.py
new file mode 100644
index 0000000..c162add
--- /dev/null
+++ b/monai/transforms.py
@@ -0,0 +1,58 @@
+
+import numpy as np
+import monai.transforms as transforms
+
+def train_transforms(crop_size, patch_size, lbl_key="label"):
+
+    monai_transforms = [
+        # pre-processing
+        transforms.LoadImaged(keys=["image", lbl_key]),
+        transforms.EnsureChannelFirstd(keys=["image", lbl_key]),
+        # NOTE: interpolation with order=2 is spline, order=1 is linear
+        transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)),
+        transforms.ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,),
+        transforms.RandCropByPosNegLabeld(keys=["image", lbl_key], label_key="label-lesion",
+                                          spatial_size=patch_size, pos=1, neg=1, num_samples=4,
+                                          # if num_samples=4, then 4
samples/image are randomly generated + image_key="image", image_threshold=0.), + # data-augmentation + transforms.RandAffined(keys=["image", lbl_key], mode=(2, 1), prob=0.9, + rotate_range=(-20. / 360 * 2. * np.pi, 20. / 360 * 2. * np.pi), # monai expects in radians + scale_range=(-0.2, 0.2), + translate_range=(-0.1, 0.1)), + transforms.Rand3DElasticd(keys=["image", lbl_key], prob=0.5, + sigma_range=(3.5, 5.5), + magnitude_range=(25., 35.)), + transforms.RandSimulateLowResolutiond(keys=["image"], zoom_range=(0.5, 1.0), prob=0.25), + transforms.RandAdjustContrastd(keys=["image"], gamma=(0.5, 3.), prob=0.5), # this is monai's RandomGamma + transforms.RandBiasFieldd(keys=["image"], coeff_range=(0.0, 0.5), degree=3, prob=0.3), + transforms.RandGaussianNoised(keys=["image"], mean=0.0, std=0.1, prob=0.1), + transforms.RandGaussianSmoothd(keys=["image"], sigma_x=(0., 2.), sigma_y=(0., 2.), sigma_z=(0., 2.0), prob=0.3), + transforms.RandScaleIntensityd(keys=["image"], factors=(-0.25, 1), prob=0.15), # this is nnUNet's BrightnessMultiplicativeTransform + transforms.RandFlipd(keys=["image", lbl_key], prob=0.3,), + transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ] + + return transforms.Compose(monai_transforms) + +def inference_transforms(crop_size, lbl_key="label"): + return transforms.Compose([ + transforms.LoadImaged(keys=["image", lbl_key], image_only=False), + transforms.EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + transforms.Orientationd(keys=["image", lbl_key], axcodes="RPI"), + transforms.Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + transforms.ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + transforms.DivisiblePadd(keys=["image", lbl_key], k=2**5), # pad inputs to ensure divisibility by no. 
of layers nnUNet has (5) + transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) + +def val_transforms(crop_size, lbl_key="label"): + return transforms.Compose([ + transforms.LoadImaged(keys=["image", lbl_key], image_only=False), + transforms.EnsureChannelFirstd(keys=["image", lbl_key]), + # CropForegroundd(keys=["image", lbl_key], source_key="image"), + transforms.Spacingd(keys=["image", lbl_key], pixdim=(1.0, 1.0, 1.0), mode=(2, 1)), # mode=("bilinear", "bilinear"),), + transforms.ResizeWithPadOrCropd(keys=["image", lbl_key], spatial_size=crop_size,), + transforms.NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=False), + ]) \ No newline at end of file From 946c72ffd1d71f6af740efd60d29cc3fe1ecf660 Mon Sep 17 00:00:00 2001 From: valosekj Date: Sat, 9 Mar 2024 08:20:58 -0500 Subject: [PATCH 13/15] fix typo --- monai/create_msd_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index 3c5acc2..ca13f40 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -92,7 +92,7 @@ def main(): # Get the training and test splits tr_subs, te_subs = train_test_split(subjects, test_size=test_ratio, random_state=args.seed) if "sci-paris" in dataset: - # add all test subjects to the to the training set + # add all test subjects to the training set tr_subs.extend(te_subs) te_subs = [] tr_subs, val_subs = train_test_split(tr_subs, test_size=val_ratio / (train_ratio + val_ratio), random_state=args.seed) From bcc972583425b49a7546cf0d2f2c10099d4a4b47 Mon Sep 17 00:00:00 2001 From: valosekj Date: Sat, 9 Mar 2024 08:21:54 -0500 Subject: [PATCH 14/15] Change train/val/test splits to 0.6 0.2 0.2 to produce the same testing subjects as https://github.com/ivadomed/model-seg-dcm/blob/bfe1d8f0d794705b53b275a15a5d88f06901ef69/dataset_conversion/convert_bids_to_nnUNetv2_region-based.py and https://github.com/ivadomed/model-seg-dcm/blob/bfe1d8f0d794705b53b275a15a5d88f06901ef69/dataset_conversion/convert_bids_to_nnUNetv2_multi-channel.py --- monai/create_msd_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index ca13f40..3ca9e2c 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -16,9 +16,9 @@ def get_parser(): parser.add_argument('--path-data', nargs='+', required=True, type=str, help='Path to BIDS dataset(s) (list).') parser.add_argument('--path-out', type=str, required=True, help='Path to the output directory where dataset json is saved') - parser.add_argument('--split', nargs='+', type=float, default=[0.7, 0.2, 0.1], + parser.add_argument('--split', nargs='+', type=float, default=[0.6, 0.2, 0.2], help='Ratios of training, validation and test splits lying between 0-1. ' - 'Example: --split 0.7 0.2 0.1') + 'Example: --split 0.6 0.2 0.2') parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility") parser.add_argument('--pathology', default='dcm', type=str, required=True, help="Type of pathology in the dataset(s). Default: 'dcm' (for dcm-zurich-lesions). 
" From d3b2740e68094b23e75c27db71a3c4151b75f60d Mon Sep 17 00:00:00 2001 From: Naga Karthik Date: Tue, 19 Mar 2024 15:48:46 -0400 Subject: [PATCH 15/15] add stuff --- configs/train.yaml | 72 +++++ monai/losses.py | 142 +++++++++ monai/main.py | 752 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 966 insertions(+) create mode 100644 configs/train.yaml create mode 100644 monai/losses.py create mode 100644 monai/main.py diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000..2279a09 --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,72 @@ +seed: 15 +save_test_preds: True + +directories: + # Path to the saved models directory + models_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models/followup + # Path to the saved results directory + results_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/results/models_followup + # Path to the saved wandb logs directory + # if None, starts training from scratch. Otherwise, resumes training from the specified wandb run folder + wandb_run_folder: None + +dataset: + # Dataset name (will be used as "group_name" for wandb logging) + name: spine-generic + # Path to the dataset directory containing all datalists (.json files) + root_dir: /home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/datalists/spine-generic/seed15 + # Type of contrast to be used for training. "all" corresponds to training on all contrasts + contrast: all # choices: ["t1w", "t2w", "t2star", "mton", "mtoff", "dwi", "all"] + # Type of label to be used for training. + label_type: soft_bin # choices: ["hard", "soft", "soft_bin"] + +preprocessing: + # Online resampling of images to the specified spacing. + spacing: [1.0, 1.0, 1.0] + # Center crop/pad images to the specified size. (NOTE: done after resampling) + # values correspond to R-L, A-P, I-S axes of the image after 1mm isotropic resampling. 
+ crop_pad_size: [96, 256, 448] + patch_size: [32, 64, 112] + +opt: + name: adam + lr: 0.001 + max_epochs: 200 + batch_size: 2 + # Interval between validation checks in epochs + check_val_every_n_epochs: 5 + # Early stopping patience (this is until patience * check_val_every_n_epochs) + early_stopping_patience: 20 + + +model: + # Model architecture to be used for training (also to be specified as args in the command line) + nnunet: + # NOTE: these info are typically taken from nnUNetPlans.json (if an nnUNet model is trained) + base_num_features: 32 + max_num_features: 320 + n_conv_per_stage_encoder: [2, 2, 2, 2, 2, 2] + n_conv_per_stage_decoder: [2, 2, 2, 2, 2] + pool_op_kernel_sizes: [ + [1, 1, 1], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + [1, 2, 2] + ] + enable_deep_supervision: True + + mednext: + num_input_channels: 1 + base_num_features: 32 + num_classes: 1 + kernel_size: 3 # 3x3x3 and 5x5x5 were tested in publication + block_counts: [2,2,2,2,1,1,1,1,1] # number of blocks in each layer + enable_deep_supervision: True + + swinunetr: + spatial_dims: 3 + depths: [2, 2, 2, 2] + num_heads: [3, 6, 12, 24] # number of heads in multi-head Attention + feature_size: 36 \ No newline at end of file diff --git a/monai/losses.py b/monai/losses.py new file mode 100644 index 0000000..29304a2 --- /dev/null +++ b/monai/losses.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import scipy +import numpy as np + + +# TODO: also check out nnUNet's implementation of soft-dice loss (if required) +# https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/training/loss/dice.py + +class SoftDiceLoss(nn.Module): + ''' + soft-dice loss, useful in binary segmentation + taken from: https://github.com/CoinCheung/pytorch-loss/blob/master/soft_dice_loss.py + ''' + def __init__(self, p=1, smooth=1): + super(SoftDiceLoss, self).__init__() + self.p = p + self.smooth = smooth + + def forward(self, logits, labels): + ''' + inputs: + preds: logits - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + preds = F.relu(logits) / F.relu(logits).max() if bool(F.relu(logits).max()) else F.relu(logits) + + numer = (preds * labels).sum() + denor = (preds.pow(self.p) + labels.pow(self.p)).sum() + # loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth) + loss = - (2 * numer + self.smooth) / (denor + self.smooth) + return loss + + +class DiceCrossEntropyLoss(nn.Module): + def __init__(self, weight_ce=1.0, weight_dice=1.0): + super(DiceCrossEntropyLoss, self).__init__() + self.ce_weight = weight_ce + self.dice_weight = weight_dice + + self.dice_loss = SoftDiceLoss() + # self.ce_loss = RobustCrossEntropyLoss() + self.ce_loss = nn.CrossEntropyLoss() + + def forward(self, preds, labels): + ''' + inputs: + preds: logits (not probabilities!) - tensor of shape (N, H, W, ...) + labels: soft labels [0,1] - tensor of shape(N, H, W, ...) + output: + loss: tensor of shape(1, ) + ''' + ce_loss = self.ce_loss(preds, labels) + + # dice loss will convert logits to probabilities + dice_loss = self.dice_loss(preds, labels) + + loss = self.ce_weight * ce_loss + self.dice_weight * dice_loss + return loss + + +class AdapWingLoss(nn.Module): + """ + Adaptive Wing loss used for heatmap regression + Adapted from: https://github.com/ivadomed/ivadomed/blob/master/ivadomed/losses.py#L341 + + .. seealso:: + Wang, Xinyao, Liefeng Bo, and Li Fuxin. "Adaptive wing loss for robust face alignment via heatmap regression." 
+        Proceedings of the IEEE International Conference on Computer Vision. 2019.
+
+    Args:
+        theta (float): Threshold to switch between the linear and non-linear parts of the piece-wise loss function.
+        alpha (float): Used to adapt the behaviour of the loss function at y=0 and y=1 and make loss smooth at 0 (background).
+        It needs to be slightly above 2 to maintain ideal properties.
+        omega (float): Multiplicative factor for non linear part of the loss.
+        epsilon (float): factor to avoid gradient explosion. It must not be too small
+        NOTE: Larger omega and smaller epsilon values will increase the influence on small errors and vice versa
+    """
+
+    def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction='sum'):
+        self.theta = theta
+        self.alpha = alpha
+        self.omega = omega
+        self.epsilon = epsilon
+        self.reduction = reduction
+        super(AdapWingLoss, self).__init__()
+
+    def forward(self, input, target):
+        eps = self.epsilon
+        batch_size = target.size()[0]
+
+        # Adaptive Wing loss. Section 4.2 of the paper.
+        # Compute adaptive factor
+        A = self.omega * (1 / (1 + torch.pow(self.theta / eps,
+                                             self.alpha - target))) * \
+            (self.alpha - target) * torch.pow(self.theta / eps,
+                                              self.alpha - target - 1) * (1 / eps)
+
+        # Constant term to link linear and non linear part
+        C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / eps, self.alpha - target)))
+
+        diff_hm = torch.abs(target - input)
+        AWingLoss = A * diff_hm - C
+        idx = diff_hm < self.theta
+        # NOTE: this is a more memory-efficient version than the one in ivadomed's losses.py
+        # where idx is True, compute the non-linear part of the loss, otherwise keep the linear part
+        # the non-linear part ensures small errors (as given by idx) have a larger influence to refine the predictions at the boundaries
+        # the linear part makes the loss function behave more like the MSE loss, which has a linear influence
+        # (i.e. small errors where y=0 --> small influence --> small gradients)
+        AWingLoss = torch.where(idx, self.omega * torch.log(1 + torch.pow(diff_hm / eps, self.alpha - target)), AWingLoss)
+
+        # Mask for weighting the loss function. Section 4.3 of the paper.
+        mask = torch.zeros_like(target)
+        kernel = scipy.ndimage.generate_binary_structure(2, 2)
+        # For 3D segmentation tasks
+        if len(input.shape) == 5:
+            kernel = scipy.ndimage.generate_binary_structure(3, 2)
+
+        for i in range(batch_size):
+            img_list = list()
+            img_list.append(np.round(target[i].cpu().numpy() * 255))
+            img_merge = np.concatenate(img_list)
+            img_dilate = scipy.ndimage.binary_opening(img_merge, np.expand_dims(kernel, axis=0))
+            # NOTE: why 51? the paper thresholds the dilated GT heatmap at 0.2. So, 51/255 = 0.2
+            img_dilate[img_dilate < 51] = 1  # 0*omega+1
+            img_dilate[img_dilate >= 51] = 1 + self.omega  # 1*omega+1
+            img_dilate = np.array(img_dilate, dtype=int)
+
+            mask[i] = torch.tensor(img_dilate)
+
+        AWingLoss *= mask
+
+        sum_loss = torch.sum(AWingLoss)
+        if self.reduction == "sum":
+            return sum_loss
+        elif self.reduction == "mean":
+            all_pixel = torch.sum(mask)
+            return sum_loss / all_pixel
\ No newline at end of file

diff --git a/monai/main.py b/monai/main.py
new file mode 100644
index 0000000..e08ad2c
--- /dev/null
+++ b/monai/main.py
@@ -0,0 +1,752 @@
+import os
+import argparse
+from datetime import datetime
+from loguru import logger
+import yaml
+
+import numpy as np
+import wandb
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+from utils import dice_score, PolyLRScheduler, plot_slices, check_empty_patch
+from losses import AdapWingLoss
+from transforms import train_transforms, val_transforms
+from models import create_nnunet_from_plans
+
+from monai.utils import set_determinism
+from monai.inferers import sliding_window_inference
+from monai.networks.nets import SwinUNETR
+from monai.data import (DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch)
+from monai.transforms import (Compose, EnsureType, EnsureTyped, Invertd, SaveImage)
+
+# mednext
+from nnunet_mednext import MedNeXt
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Script for (pre)training lesion segmentation model for DCM and SCI patients.')
+
+    # arguments for model
+    parser.add_argument('-m', '--model', choices=['swinunetr', 'nnunet', 'mednext'],
+                        default='nnunet', type=str,
+                        help='Model type to be used. Currently only supports nnUNet.')
+    # path to the config file
+    parser.add_argument("--config", type=str, default="./config.json",
+                        help="Path to the config file containing all training details.")
+    # saving
+    parser.add_argument('--debug', default=False, action='store_true', help='if true, results are not logged to wandb')
+    parser.add_argument('-c', '--continue_from_checkpoint', default=False, action='store_true',
+                        help='Load model from checkpoint and continue training')
+    args = parser.parse_args()
+
+    return args
+
+
+# create a "model"-agnostic class with PL to use different models
+class Model(pl.LightningModule):
+    def __init__(self, config, data_root, net, loss_function, optimizer_class, exp_id=None, results_path=None):
+        super().__init__()
+        self.cfg = config
+        self.save_hyperparameters(ignore=['net', 'loss_function'])
+
+        self.root = data_root
+        self.net = net
+        self.lr = config["opt"]["lr"]
+        self.loss_function = loss_function
+        self.optimizer_class = optimizer_class
+        self.save_exp_id = exp_id
+        self.results_path = results_path
+
+        self.best_val_dice, self.best_val_epoch = 0, 0
+        self.best_val_loss = float("inf")
+
+        # define cropping and padding dimensions
+        # NOTE about patch sizes: nnUNet defines patches using the median size of the dataset as the reference
+        # BUT, for SC images, this means a lot of context outside the spinal cord is included in the patches
+        # which could be sub-optimal.
+        # On the other hand, ivadomed used a patch-size that's heavily cropped along the R-L direction so that
+        # only the SC is in context.
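+        # (note: with the shipped configs/train.yaml, crop_pad_size resolves to
+        # [96, 256, 448] and patch_size to [32, 64, 112], i.e. training patches
+        # tightly cropped along R-L, as described above)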
+        self.spacing = config["preprocessing"]["spacing"]
+        self.voxel_cropping_size = self.inference_roi_size = config["preprocessing"]["crop_pad_size"]
+        self.patch_size = config["preprocessing"]["patch_size"]
+
+        # define post-processing transforms for validation, nothing fancy, just making sure that it's a tensor (default)
+        self.val_post_pred = self.val_post_label = Compose([EnsureType()])
+
+        # define evaluation metric
+        self.soft_dice_metric = dice_score
+
+        # temp lists for storing outputs from training, validation, and testing
+        self.train_step_outputs = []
+        self.val_step_outputs = []
+        self.test_step_outputs = []
+
+
+    # --------------------------------
+    # FORWARD PASS
+    # --------------------------------
+    def forward(self, x):
+
+        out = self.net(x)
+        # # NOTE: MONAI's models only output the logits, not the output after the final activation function
+        # # https://docs.monai.io/en/0.9.0/_modules/monai/networks/nets/unetr.html#UNETR.forward refers to the
+        # # UnetOutBlock (https://docs.monai.io/en/0.9.0/_modules/monai/networks/blocks/dynunet_block.html#UnetOutBlock)
+        # # as the final block applied to the input, which is just a convolutional layer with no activation function
+        # # Hence, we use a normalized ReLU to map the logits to the final output
+        # normalized_out = F.relu(out) / F.relu(out).max() if bool(F.relu(out).max()) else F.relu(out)
+
+        return out  # returns logits
+
+
+    # --------------------------------
+    # DATA PREPARATION
+    # --------------------------------
+    def prepare_data(self):
+        # set deterministic training for reproducibility
+        set_determinism(seed=self.cfg["seed"])
+
+        # define training and validation transforms
+        transforms_train = train_transforms(
+            crop_size=self.voxel_cropping_size,
+            patch_size=self.patch_size,
+            lbl_key='label'
+        )
+        transforms_val = val_transforms(crop_size=self.inference_roi_size, lbl_key='label')
+
+        # load the dataset
+        logger.info(f"Training with {self.cfg['dataset']['label_type']} labels ...")
+        dataset = os.path.join(self.root,
+                               f"dataset_{self.cfg['dataset']['contrast']}_{self.cfg['dataset']['label_type']}_seed{self.cfg['seed']}.json"
+                               )
+        logger.info(f"Loading dataset: {dataset}")
+        train_files = load_decathlon_datalist(dataset, True, "train")
+        val_files = load_decathlon_datalist(dataset, True, "validation")
+        test_files = load_decathlon_datalist(dataset, True, "test")
+
+        if args.debug:
+            train_files = train_files[:15]
+            val_files = val_files[:15]
+            test_files = test_files[:6]
+
+        train_cache_rate = 0.25 if args.debug else 0.5
+        self.train_ds = CacheDataset(data=train_files, transform=transforms_train, cache_rate=train_cache_rate, num_workers=4)
+        self.val_ds = CacheDataset(data=val_files, transform=transforms_val, cache_rate=0.25, num_workers=4)
+
+        # define test transforms
+        transforms_test = val_transforms(crop_size=self.inference_roi_size, lbl_key='label')
+
+        # define post-processing transforms for testing; taken (with explanations) from
+        # https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_inference_dict.py#L66
+        self.test_post_pred = Compose([
+            EnsureTyped(keys=["pred", "label"]),
+            Invertd(keys=["pred", "label"], transform=transforms_test,
+                    orig_keys=["image", "label"],
+                    meta_keys=["pred_meta_dict", "label_meta_dict"],
+                    nearest_interp=False, to_tensor=True),
+        ])
+        self.test_ds = CacheDataset(data=test_files, transform=transforms_test, cache_rate=0.1, num_workers=4)
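+        # A sketch of what Invertd does here (assumed MONAI behaviour, for intuition):
+        # if transforms_test resampled/cropped the image, then for a decollated batch
+        # dict carrying a "pred" computed in that preprocessed space,
+        #     post = self.test_post_pred(batch_dict)   # batch_dict is hypothetical
+        # returns "pred" (and "label") mapped back onto the original image grid, so that
+        # saved predictions overlay the native-resolution scan.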
+    # --------------------------------
+    # DATA LOADERS
+    # --------------------------------
+    def train_dataloader(self):
+        return DataLoader(self.train_ds, batch_size=self.cfg["opt"]["batch_size"], shuffle=True, num_workers=16,
+                          pin_memory=True, persistent_workers=True)
+
+    def val_dataloader(self):
+        return DataLoader(self.val_ds, batch_size=1, shuffle=False, num_workers=16, pin_memory=True,
+                          persistent_workers=True)
+
+    def test_dataloader(self):
+        return DataLoader(self.test_ds, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
+
+
+    # --------------------------------
+    # OPTIMIZATION
+    # --------------------------------
+    def configure_optimizers(self):
+        if self.cfg["opt"]["name"] == "sgd":
+            optimizer = self.optimizer_class(self.parameters(), lr=self.lr, momentum=0.99, weight_decay=3e-5, nesterov=True)
+        else:
+            optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
+        # scheduler = PolyLRScheduler(optimizer, self.lr, max_steps=self.args.max_epochs)
+        # NOTE: ivadomed uses CosineAnnealingLR with T_max = 200
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.cfg["opt"]["max_epochs"])
+        return [optimizer], [scheduler]
+
+
+    # --------------------------------
+    # TRAINING
+    # --------------------------------
+    def training_step(self, batch, batch_idx):
+
+        inputs, labels = batch["image"], batch["label"]
+
+        # check if any label image patch is empty in the batch
+        if check_empty_patch(labels) is None:
+            # print(f"Empty label patch found. Skipping training step ...")
+            return None
+
+        output = self.forward(inputs)   # logits
+        # print(f"labels.shape: {labels.shape} \t output.shape: {output.shape}")
+
+        if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
+
+            # calculate the dice loss for each output
+            loss, train_soft_dice = 0.0, 0.0
+            for i in range(len(output)):
+                # give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+                # this gives higher-resolution outputs more weight in the loss
+                # NOTE: outputs[0] is the final pred, outputs[-1] is the lowest-resolution pred (at the bottleneck)
+                # we're downsampling the GT to the resolution of each deep-supervision feature map output
+                # (instead of upsampling each deep-supervision feature map output to the final resolution)
+                downsampled_gt = F.interpolate(labels, size=output[i].shape[-3:], mode='trilinear', align_corners=False)
+                # print(f"downsampled_gt.shape: {downsampled_gt.shape} \t output[i].shape: {output[i].shape}")
+                loss += (0.5 ** i) * self.loss_function(output[i], downsampled_gt)
+
+                # get probabilities from logits
+                out = F.relu(output[i]) / F.relu(output[i]).max() if bool(F.relu(output[i]).max()) else F.relu(output[i])
+
+                # calculate train dice
+                # NOTE: this is done on patches (and not the entire 3D volume) because SlidingWindowInference is not used here
+                # So, take this dice score with a lot of salt
+                train_soft_dice += self.soft_dice_metric(out, downsampled_gt)
+
+            # average the dice loss across the outputs
+            loss /= len(output)
+            train_soft_dice /= len(output)
+
+        else:
+            # calculate the training loss
+            loss = self.loss_function(output, labels)
+
+            # get probabilities from logits
+            output = F.relu(output) / F.relu(output).max() if bool(F.relu(output).max()) else F.relu(output)
+
+            # calculate train dice
+            # NOTE: this is done on patches (and not the entire 3D volume) because SlidingWindowInference is not used here
+            # So, take this dice score with a lot of salt
+            train_soft_dice = self.soft_dice_metric(output, labels)
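+        # Worked example of the deep-supervision weighting above (head count assumed):
+        # with 5 outputs, the per-output loss weights are 0.5**i = 1, 0.5, 0.25, 0.125, 0.0625
+        # for i = 0..4, and the summed loss is then divided by 5, so the full-resolution
+        # head still dominates the objective.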
+        metrics_dict = {
+            "loss": loss.cpu(),
+            "train_soft_dice": train_soft_dice.detach().cpu(),
+            "train_number": len(inputs),
+            # "train_image": inputs[0].detach().cpu().squeeze(),
+            # "train_gt": labels[0].detach().cpu().squeeze(),
+            # "train_pred": output[0].detach().cpu().squeeze()
+        }
+        self.train_step_outputs.append(metrics_dict)
+
+        return metrics_dict
+
+    def on_train_epoch_end(self):
+
+        if self.train_step_outputs == []:
+            # means the training step was skipped because of an empty input patch
+            return None
+        else:
+            train_loss, train_soft_dice = 0, 0
+            num_items = len(self.train_step_outputs)
+            for output in self.train_step_outputs:
+                train_loss += output["loss"].item()
+                train_soft_dice += output["train_soft_dice"].item()
+
+            mean_train_loss = (train_loss / num_items)
+            mean_train_soft_dice = (train_soft_dice / num_items)
+
+            wandb_logs = {
+                "train_soft_dice": mean_train_soft_dice,
+                "train_loss": mean_train_loss,
+            }
+            self.log_dict(wandb_logs)
+
+            # # plot the training images
+            # fig = plot_slices(image=self.train_step_outputs[0]["train_image"],
+            #                   gt=self.train_step_outputs[0]["train_gt"],
+            #                   pred=self.train_step_outputs[0]["train_pred"],
+            #                   debug=args.debug)
+            # wandb.log({"training images": wandb.Image(fig)})
+
+            # free up memory
+            self.train_step_outputs.clear()
+            wandb_logs.clear()
+            # plt.close(fig)
+
+
+    # --------------------------------
+    # VALIDATION
+    # --------------------------------
+    def validation_step(self, batch, batch_idx):
+
+        inputs, labels = batch["image"], batch["label"]
+
+        # NOTE: this calculates the loss on the entire image after sliding-window inference
+        outputs = sliding_window_inference(inputs, self.inference_roi_size, mode="gaussian",
+                                           sw_batch_size=4, predictor=self.forward, overlap=0.5,)
+        # outputs shape: (B, C, H, W, D)
+        if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
+            # we only need the output with the highest resolution
+            outputs = outputs[0]
+
+        # calculate the validation loss
+        loss = self.loss_function(outputs, labels)
+
+        # get probabilities from logits
+        outputs = F.relu(outputs) / F.relu(outputs).max() if bool(F.relu(outputs).max()) else F.relu(outputs)
+
+        # post-process for calculating the evaluation metric
+        post_outputs = [self.val_post_pred(i) for i in decollate_batch(outputs)]
+        post_labels = [self.val_post_label(i) for i in decollate_batch(labels)]
+        val_soft_dice = self.soft_dice_metric(post_outputs[0], post_labels[0])
+
+        hard_preds, hard_labels = (post_outputs[0].detach() > 0.5).float(), (post_labels[0].detach() > 0.5).float()
+        val_hard_dice = self.soft_dice_metric(hard_preds, hard_labels)
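+        # For intuition, the normalized-ReLU trick above behaves like this (sketch on a
+        # toy 1D tensor, not code from the training loop):
+        #   >>> logits = torch.tensor([-2.0, 0.0, 3.0, 6.0])
+        #   >>> F.relu(logits) / F.relu(logits).max()
+        #   tensor([0.0000, 0.0000, 0.5000, 1.0000])
+        # i.e. negative logits are clamped to 0 and the rest are rescaled to [0, 1],
+        # standing in for a final sigmoid; the max() guard avoids dividing by zero when
+        # all logits are negative.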
+        # NOTE: there was a massive memory leak when storing cuda tensors in this dict; hence,
+        # using .detach() to avoid storing the whole computation graph
+        # Ref: https://discuss.pytorch.org/t/cuda-memory-leak-while-training/82855/2
+        metrics_dict = {
+            "val_loss": loss.detach().cpu(),
+            "val_soft_dice": val_soft_dice.detach().cpu(),
+            "val_hard_dice": val_hard_dice.detach().cpu(),
+            "val_number": len(post_outputs),
+            # "val_image": inputs[0].detach().cpu().squeeze(),
+            # "val_gt": labels[0].detach().cpu().squeeze(),
+            # "val_pred": post_outputs[0].detach().cpu().squeeze(),
+        }
+        self.val_step_outputs.append(metrics_dict)
+
+        return metrics_dict
+
+    def on_validation_epoch_end(self):
+
+        val_loss, num_items, val_soft_dice, val_hard_dice = 0, 0, 0, 0
+        for output in self.val_step_outputs:
+            val_loss += output["val_loss"].sum().item()
+            val_soft_dice += output["val_soft_dice"].sum().item()
+            val_hard_dice += output["val_hard_dice"].sum().item()
+            num_items += output["val_number"]
+
+        mean_val_loss = (val_loss / num_items)
+        mean_val_soft_dice = (val_soft_dice / num_items)
+        mean_val_hard_dice = (val_hard_dice / num_items)
+
+        wandb_logs = {
+            "val_soft_dice": mean_val_soft_dice,
+            "val_hard_dice": mean_val_hard_dice,
+            "val_loss": mean_val_loss,
+        }
+        # track the best epoch based on the validation dice score
+        if mean_val_soft_dice > self.best_val_dice:
+            self.best_val_dice = mean_val_soft_dice
+            self.best_val_epoch = self.current_epoch
+
+        # track the best epoch based on the validation (AdapWing) loss
+        if mean_val_loss < self.best_val_loss:
+            self.best_val_loss = mean_val_loss
+            self.best_val_epoch = self.current_epoch
+
+        logger.info(
+            f"\nCurrent epoch: {self.current_epoch}"
+            f"\nAverage Soft Dice (VAL): {mean_val_soft_dice:.4f}"
+            f"\nAverage Hard Dice (VAL): {mean_val_hard_dice:.4f}"
+            f"\nAverage AdapWing Loss (VAL): {mean_val_loss:.4f}"
+            f"\nBest Average AdapWing Loss: {self.best_val_loss:.4f} at Epoch: {self.best_val_epoch}"
+            f"\n----------------------------------------------------")
+
+        # log to wandb
+        self.log_dict(wandb_logs)
+
+        # # plot the validation images
+        # fig = plot_slices(image=self.val_step_outputs[0]["val_image"],
+        #                   gt=self.val_step_outputs[0]["val_gt"],
+        #                   pred=self.val_step_outputs[0]["val_pred"],)
+        # wandb.log({"validation images": wandb.Image(fig)})
+
+        # free up memory
+        self.val_step_outputs.clear()
+        wandb_logs.clear()
+        # plt.close(fig)
+
+        # return {"log": wandb_logs}
+
+    # --------------------------------
+    # TESTING
+    # --------------------------------
+    def test_step(self, batch, batch_idx):
+
+        test_input = batch["image"]
+        # print(batch["label_meta_dict"]["filename_or_obj"][0])
+        batch["pred"] = sliding_window_inference(test_input, self.inference_roi_size,
+                                                 sw_batch_size=4, predictor=self.forward, overlap=0.5)
+
+        if args.model in ["nnunet", "mednext"] and self.cfg['model'][args.model]["enable_deep_supervision"]:
+            # we only need the output with the highest resolution
+            batch["pred"] = batch["pred"][0]
+
+        # normalize the logits
+        batch["pred"] = F.relu(batch["pred"]) / F.relu(batch["pred"]).max() if bool(F.relu(batch["pred"]).max()) else F.relu(batch["pred"])
+
+        post_test_out = [self.test_post_pred(i) for i in decollate_batch(batch)]
+
+        # make sure that the shapes of the prediction and the GT label are the same
+        # print(f"pred shape: {post_test_out[0]['pred'].shape}, label shape: {post_test_out[0]['label'].shape}")
+        assert post_test_out[0]['pred'].shape == post_test_out[0]['label'].shape
+
+        pred, label = post_test_out[0]['pred'].cpu(), post_test_out[0]['label'].cpu()
+
+        # save the prediction and label
+        if self.cfg["save_test_preds"]:
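+            # Assumed on-disk layout of the SaveImage calls below (standard MONAI naming,
+            # filename purely illustrative): with separate_folder=False and
+            # output_postfix="pred", a test image sub-001_ses-01_T2w.nii.gz would be saved
+            # as <results_path>/sub-001/sub-001_ses-01_T2w_pred.nii.gz.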
+            subject_name = (batch["image_meta_dict"]["filename_or_obj"][0]).split("/")[-1].replace(".nii.gz", "")
+            logger.info(f"Saving subject: {subject_name}")
+
+            # image saver class
+            save_folder = os.path.join(self.results_path, subject_name.split("_")[0])
+            pred_saver = SaveImage(
+                output_dir=save_folder, output_postfix="pred", output_ext=".nii.gz",
+                separate_folder=False, print_log=False, resample=True)
+            # save the prediction
+            pred_saver(pred)
+
+            # label_saver = SaveImage(
+            #     output_dir=save_folder, output_postfix="gt", output_ext=".nii.gz",
+            #     separate_folder=False, print_log=False, resample=True)
+            # # save the label
+            # label_saver(label)
+
+
+        # NOTE: Important point from the SoftSeg paper - binarize predictions before computing metrics
+        # calculate soft and hard dice here (for a quick overview); other metrics can be computed from
+        # the saved predictions using ANIMA
+        # 1. Dice Score
+        test_soft_dice = self.soft_dice_metric(pred, label)
+
+        # binarize the predictions
+        pred = (post_test_out[0]['pred'].detach().cpu() > 0.5).float()
+        label = (post_test_out[0]['label'].detach().cpu() > 0.5).float()
+
+        # 1.1 Hard Dice Score
+        test_hard_dice = self.soft_dice_metric(pred.numpy(), label.numpy())
+
+        metrics_dict = {
+            "test_hard_dice": test_hard_dice,
+            "test_soft_dice": test_soft_dice,
+        }
+        self.test_step_outputs.append(metrics_dict)
+
+        return metrics_dict
+
+    def on_test_epoch_end(self):
+
+        avg_hard_dice_test, std_hard_dice_test = np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).mean(), \
+                                                 np.stack([x["test_hard_dice"] for x in self.test_step_outputs]).std()
+        avg_soft_dice_test, std_soft_dice_test = np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).mean(), \
+                                                 np.stack([x["test_soft_dice"] for x in self.test_step_outputs]).std()
+
+        logger.info(f"Test (Soft) Dice: {avg_soft_dice_test}")
+        logger.info(f"Test (Hard) Dice: {avg_hard_dice_test}")
+
+        self.avg_test_dice, self.std_test_dice = avg_soft_dice_test, std_soft_dice_test
+        self.avg_test_dice_hard, self.std_test_dice_hard = avg_hard_dice_test, std_hard_dice_test
+
+        # free up memory
+        self.test_step_outputs.clear()
+
+
+# --------------------------------
+# MAIN
+# --------------------------------
+def main(args):
+
+    # load the config file
+    with open(args.config, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    # Setting the seed
+    pl.seed_everything(config["seed"], workers=True)
+
+    # define root path for finding datalists
+    dataset_root = config["dataset"]["root_dir"]
+
+    # define optimizer
+    if config["opt"]["name"] == "adam":
+        optimizer_class = torch.optim.Adam
+    elif config["opt"]["name"] == "sgd":
+        optimizer_class = torch.optim.SGD
+
+    # define models
+    if args.model in ["swinunetr"]:
+        # define model
+        net = SwinUNETR(spatial_dims=config["model"]["swinunetr"]["spatial_dims"],
+                        in_channels=1, out_channels=1,
+                        img_size=config["preprocessing"]["crop_pad_size"],
+                        depths=config["model"]["swinunetr"]["depths"],
+                        feature_size=config["model"]["swinunetr"]["feature_size"],
+                        num_heads=config["model"]["swinunetr"]["num_heads"],
+                        )
+        # variable for saving the patch size in the experiment id (same as crop_pad_size)
+        patch_size = f"{config['preprocessing']['crop_pad_size'][0]}x" \
+                     f"{config['preprocessing']['crop_pad_size'][1]}x" \
+                     f"{config['preprocessing']['crop_pad_size'][2]}"
+        # save experiment id
+        save_exp_id = f"{args.model}_seed={config['seed']}_" \
+                      f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \
+                      f"d={config['model']['swinunetr']['depths'][0]}_" \
+                      f"nf={config['model']['swinunetr']['feature_size']}_" \
+                      f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \
+                      f"bs={config['opt']['batch_size']}_{patch_size}"
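+        # Purely illustrative example of the resulting experiment ID (values assumed):
+        #   swinunetr_seed=42_T2w_soft_d=2_nf=36_opt=adam_lr=0.001_AdapW_bs=2_64x192x320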
+
+    elif args.model == "nnunet":
+
+        if config["model"]["nnunet"]["enable_deep_supervision"]:
+            logger.info(f"Using nnUNet model WITH deep supervision ...")
+        else:
+            logger.info(f"Using nnUNet model WITHOUT deep supervision ...")
+
+        logger.info("Defining plans for nnUNet model ...")
+        # =========================================================================================
+        # Define the plans json, taken from the nnUNet_preprocessed folder
+        # =========================================================================================
+        nnunet_plans = {
+            "UNet_class_name": "PlainConvUNet",
+            "UNet_base_num_features": config["model"]["nnunet"]["base_num_features"],
+            "n_conv_per_stage_encoder": config["model"]["nnunet"]["n_conv_per_stage_encoder"],
+            "n_conv_per_stage_decoder": config["model"]["nnunet"]["n_conv_per_stage_decoder"],
+            "pool_op_kernel_sizes": config["model"]["nnunet"]["pool_op_kernel_sizes"],
+            "conv_kernel_sizes": [
+                [3, 3, 3],
+                [3, 3, 3],
+                [3, 3, 3],
+                [3, 3, 3],
+                [3, 3, 3],
+                [3, 3, 3]
+            ],
+            "unet_max_num_features": config["model"]["nnunet"]["max_num_features"],
+        }
+
+        # define model
+        net = create_nnunet_from_plans(plans=nnunet_plans, num_input_channels=1, num_classes=1,
+                                       deep_supervision=config["model"]["nnunet"]["enable_deep_supervision"])
+        # variable for saving the patch size in the experiment id (same as crop_pad_size)
+        patch_size = f"{config['preprocessing']['crop_pad_size'][0]}x" \
+                     f"{config['preprocessing']['crop_pad_size'][1]}x" \
+                     f"{config['preprocessing']['crop_pad_size'][2]}"
+        # save experiment id
+        save_exp_id = f"{args.model}_seed={config['seed']}_" \
+                      f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \
+                      f"nf={config['model']['nnunet']['base_num_features']}_" \
+                      f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \
+                      f"bs={config['opt']['batch_size']}_{patch_size}"
+
+        if args.debug:
+            save_exp_id = f"DEBUG_{save_exp_id}"
+
+    elif args.model == "mednext":
+        # NOTE: the S and B models from the paper don't fit on the GPU as-is for this data,
+        # hence the models are tweaked here
+        logger.info(f"Using a tweaked MedNeXt model ...")
+        net = MedNeXt(
+            in_channels=config["model"]["mednext"]["num_input_channels"],
+            n_channels=config["model"]["mednext"]["base_num_features"],
+            n_classes=config["model"]["mednext"]["num_classes"],
+            exp_r=2,
+            kernel_size=config["model"]["mednext"]["kernel_size"],
+            deep_supervision=config["model"]["mednext"]["enable_deep_supervision"],
+            do_res=True,
+            do_res_up_down=True,
+            checkpoint_style="outside_block",
+            block_counts=config["model"]["mednext"]["block_counts"],
+        )
+
+        # variable for saving the patch size in the experiment id (same as crop_pad_size)
+        patch_size = f"{config['preprocessing']['crop_pad_size'][0]}x" \
+                     f"{config['preprocessing']['crop_pad_size'][1]}x" \
+                     f"{config['preprocessing']['crop_pad_size'][2]}"
+        # count the number of 2s in the block_counts list
+        num_two_blocks = config["model"]["mednext"]["block_counts"].count(2)
+        # save experiment id
+        save_exp_id = f"{args.model}_seed={config['seed']}_" \
+                      f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \
+                      f"nf={config['model']['mednext']['base_num_features']}_bcs={num_two_blocks}_" \
+                      f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \
+                      f"bs={config['opt']['batch_size']}_{patch_size}"
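+        # Hypothetical config values the MedNeXt block above expects (illustrative only):
+        #   model:
+        #     mednext:
+        #       num_input_channels: 1
+        #       base_num_features: 32
+        #       num_classes: 1
+        #       kernel_size: 3
+        #       block_counts: [2, 2, 2, 2, 2, 2, 2, 2, 2]
+        #       enable_deep_supervision: true
+        # With that block_counts, num_two_blocks = 9 and "bcs=9" is embedded in the
+        # experiment ID above.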
+
+        if args.debug:
+            save_exp_id = f"DEBUG_{save_exp_id}"
+
+    timestamp = datetime.now().strftime("%Y%m%d-%H%M")   # prints in YYYYMMDD-HHMM format
+    save_exp_id = f"{save_exp_id}_{timestamp}"
+
+    # save the output to a log file
+    logger.add(os.path.join(config["directories"]["models_dir"], f"{save_exp_id}", "logs.txt"), rotation="10 MB", level="INFO")
+
+    # save the config file to the output folder
+    with open(os.path.join(config["directories"]["models_dir"], f"{save_exp_id}", "config.yaml"), "w") as f:
+        yaml.dump(config, f)
+
+    # define loss function
+    loss_func = AdapWingLoss(theta=0.5, omega=8, alpha=2.1, epsilon=1, reduction="sum")
+    # NOTE: tried increasing omega and decreasing epsilon, but the results were marginally worse than the above
+    # loss_func = AdapWingLoss(theta=0.5, omega=12, alpha=2.1, epsilon=0.5, reduction="sum")
+    logger.info(f"Using AdapWingLoss with theta={loss_func.theta}, omega={loss_func.omega}, alpha={loss_func.alpha}, epsilon={loss_func.epsilon} ...")
+
+    # define callbacks
+    early_stopping = pl.callbacks.EarlyStopping(
+        monitor="val_loss", min_delta=0.00,
+        patience=config["opt"]["early_stopping_patience"],
+        verbose=False, mode="min")
+
+    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
+
+    # training from scratch
+    if not args.continue_from_checkpoint:
+        # to save the best model on validation
+        save_path = os.path.join(config["directories"]["models_dir"], f"{save_exp_id}")
+        if not os.path.exists(save_path):
+            os.makedirs(save_path, exist_ok=True)
+
+        # to save the results/model predictions
+        results_path = os.path.join(config["directories"]["results_dir"], f"{save_exp_id}")
+        if not os.path.exists(results_path):
+            os.makedirs(results_path, exist_ok=True)
+
+        # i.e. train with weights initialized from scratch
+        pl_model = Model(config, data_root=dataset_root,
+                         optimizer_class=optimizer_class, loss_function=loss_func, net=net,
+                         exp_id=save_exp_id, results_path=results_path)
+
+        # saving the best model based on validation loss
+        logger.info(f"Saving best model to {save_path}!")
+        checkpoint_callback_loss = pl.callbacks.ModelCheckpoint(
+            dirpath=save_path, filename='best_model', monitor='val_loss',
+            save_top_k=1, mode="min", save_last=True, save_weights_only=False)
+
+        # # saving the best model based on the soft validation dice score
+        # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint(
+        #     dirpath=save_path, filename='best_model_dice', monitor='val_soft_dice',
+        #     save_top_k=1, mode="max", save_last=False, save_weights_only=True)
+
+        logger.info(f"Starting training from scratch ...")
+        # wandb logger
+        exp_logger = pl.loggers.WandbLogger(
+            name=save_exp_id,
+            save_dir="/home/GRAMES.POLYMTL.CA/u114716/contrast-agnostic/saved_models",
+            group=config["dataset"]["name"],
+            log_model=True,  # save the best model using the checkpoint callback
+            project='contrast-agnostic',
+            entity='naga-karthik',
+            config=config)
+
+        # Saving the training scripts to wandb
+        wandb.save("main.py")
+        wandb.save("transforms.py")
+
+        # initialise Lightning's trainer
+        trainer = pl.Trainer(
+            devices=1, accelerator="gpu",
+            logger=exp_logger,
+            callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping],
+            check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"],
+            max_epochs=config["opt"]["max_epochs"],
+            precision=32,
+            # deterministic=True,
+            enable_progress_bar=False)
+            # profiler="simple",)  # to profile the training time taken for each step
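+        # Assumed config["opt"] keys consumed above (illustrative values only):
+        #   opt:
+        #     name: adam
+        #     lr: 0.001
+        #     batch_size: 2
+        #     max_epochs: 200
+        #     check_val_every_n_epochs: 10
+        #     early_stopping_patience: 20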
+        # Train!
+        trainer.fit(pl_model)
+        logger.info(f" Training Done!")
+
+    else:
+        logger.info(f" Resuming training from the latest checkpoint! ")
+
+        # check if the wandb run folder is provided to resume using the same run
+        if config["directories"]["wandb_run_folder"] is None:
+            raise ValueError("Please provide the wandb run folder to resume training using the same run on WandB!")
+        else:
+            wandb_run_folder = os.path.basename(config["directories"]["wandb_run_folder"])
+            wandb_run_id = wandb_run_folder.split("-")[-1]
+
+        save_exp_id = config["directories"]["models_dir"]
+        save_path = os.path.dirname(config["directories"]["models_dir"])
+        logger.info(f"save_path: {save_path}")
+        results_path = config["directories"]["results_dir"]
+
+        # i.e. train by loading the existing weights
+        pl_model = Model(config, data_root=dataset_root,
+                         optimizer_class=optimizer_class, loss_function=loss_func, net=net,
+                         exp_id=save_exp_id, results_path=results_path)
+
+        # saving the best model based on validation loss
+        checkpoint_callback_loss = pl.callbacks.ModelCheckpoint(
+            dirpath=save_exp_id, filename='best_model', monitor='val_loss',
+            save_top_k=1, mode="min", save_last=True, save_weights_only=True)
+
+        # # saving the best model based on the soft validation dice score
+        # checkpoint_callback_dice = pl.callbacks.ModelCheckpoint(
+        #     dirpath=save_exp_id, filename='best_model_dice', monitor='val_soft_dice',
+        #     save_top_k=1, mode="max", save_last=False, save_weights_only=True)
+
+        # wandb logger
+        exp_logger = pl.loggers.WandbLogger(
+            save_dir=save_path,
+            group=config["dataset"]["name"],
+            log_model=True,  # save the best model using the checkpoint callback
+            project='contrast-agnostic',
+            entity='naga-karthik',
+            config=config,
+            id=wandb_run_id, resume='must')
+
+        # initialise Lightning's trainer
+        trainer = pl.Trainer(
+            devices=1, accelerator="gpu",
+            logger=exp_logger,
+            callbacks=[checkpoint_callback_loss, lr_monitor, early_stopping],
+            check_val_every_n_epoch=config["opt"]["check_val_every_n_epochs"],
+            max_epochs=config["opt"]["max_epochs"],
+            precision=32,
+            enable_progress_bar=True)
+            # profiler="simple",)  # to profile the training time taken for each step
+
+        # Train!
+        trainer.fit(pl_model, ckpt_path=os.path.join(save_exp_id, "last.ckpt"),)
+        logger.info(f" Training Done!")
+
+    # Test!
+    trainer.test(pl_model)
+    logger.info(f"TESTING DONE!")
+
+    # close the current wandb instance so that a new one is created for the next fold
+    wandb.finish()
+
+    # TODO: Figure out saving test metrics to a file
+    with open(os.path.join(results_path, 'test_metrics.txt'), 'a') as f:
+        print('\n-------------- Test Metrics ----------------', file=f)
+        print(f"{args.model}_seed={config['seed']}_" \
+              f"{config['dataset']['contrast']}_{config['dataset']['label_type']}_" \
+              f"nf={config['model']['nnunet']['base_num_features']}_" \
+              f"opt={config['opt']['name']}_lr={config['opt']['lr']}_AdapW_" \
+              f"bs={config['opt']['batch_size']}_{patch_size}" \
+              f"_{timestamp}", file=f)
+
+        print('\n-------------- Test Hard Dice Scores ----------------', file=f)
+        print("Hard Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice_hard, pl_model.std_test_dice_hard), file=f)
+
+        print('\n-------------- Test Soft Dice Scores ----------------', file=f)
+        print("Soft Dice --> Mean: %0.3f, Std: %0.3f" % (pl_model.avg_test_dice, pl_model.std_test_dice), file=f)
+
+        print('-------------------------------------------------------', file=f)
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
\ No newline at end of file