Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monai training of an MS lesion segmentation model #12

Draft
wants to merge 108 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
108 commits
Select commit Hold shift + click to select a range
bbde8f2
created script to build msd dataset for monai nnunet model training
plbenveniste Mar 11, 2024
081276a
added requirements script
plbenveniste Mar 11, 2024
0ef3f0a
removed yaml from requirements
plbenveniste Mar 11, 2024
451ca30
changed output file names
plbenveniste Mar 12, 2024
1dbd553
added missing requirements
plbenveniste Mar 12, 2024
91bac32
initialised config.yml file example
plbenveniste Mar 12, 2024
48f109c
initialised main file
plbenveniste Mar 12, 2024
5aa2eae
initialised models file
plbenveniste Mar 12, 2024
5b233d2
initalised transforms file
plbenveniste Mar 12, 2024
794430a
modified to fit our training parameters
plbenveniste Mar 12, 2024
0619980
simplied main training script
plbenveniste Mar 12, 2024
e04c423
fixed canproco problem (having both img and label with same link)
plbenveniste Mar 13, 2024
26629c8
added training script for a monai trained UNETR
plbenveniste Mar 13, 2024
e815e14
added fake config file (for debugging)
plbenveniste Mar 13, 2024
5aaf6db
updated requirements to add monai[all] for problem with data loading
plbenveniste Mar 13, 2024
8eafeff
removed old config file
plbenveniste Mar 13, 2024
8f43eac
changed links to dataset
plbenveniste Mar 13, 2024
86e1a32
fixed training. Still need to fix validation params
plbenveniste Mar 13, 2024
32ca1cd
removed files using pytorch ligthning training
plbenveniste Mar 13, 2024
5647955
working monai script based on Jan's code: but no dice score improvement
plbenveniste Mar 14, 2024
a4d7958
pytorch lightning script based on Naga's work
plbenveniste Mar 14, 2024
4e768e4
modified requirements for pytorch lightning
plbenveniste Mar 14, 2024
186c602
added multiply by -1 transform
plbenveniste Mar 18, 2024
adaf04f
parameters changed for config file for first inference run
plbenveniste Mar 18, 2024
dab7dcc
added SoftDiceLoss
plbenveniste Mar 28, 2024
80494e5
removed print from inverse function in utils
plbenveniste Mar 28, 2024
4727dbc
changed resolution to 0.6 isotropic
plbenveniste Mar 28, 2024
5992732
added plot images function for wandb
plbenveniste Mar 28, 2024
9aefb6d
changed loss function and added image printing
plbenveniste Mar 28, 2024
855e26f
changed some parameters for training
plbenveniste Apr 1, 2024
e1ed6e2
added the image plot function
plbenveniste Apr 1, 2024
cffa941
changed model parameters for training
plbenveniste Apr 1, 2024
fc37ca0
code reviewed with no prob but output still problematic
plbenveniste Apr 1, 2024
9f1effd
added lines to save images before training
plbenveniste Apr 1, 2024
01c0912
correction: removed intensity normalisation for labels
plbenveniste Apr 1, 2024
6aa8225
fixed filename to save
plbenveniste Apr 1, 2024
a89b5d7
updated to add some data aug but then removed :/
plbenveniste Apr 2, 2024
6342be0
created file for unet training with multiple input channels
plbenveniste Apr 3, 2024
b858fbc
created file for unet training with multiple output channels
plbenveniste Apr 3, 2024
afb7b4f
training script cleaned for ms lesion seg
plbenveniste Apr 3, 2024
9d6af9b
moved all files in monai and removed nnunet folder
plbenveniste Apr 4, 2024
97b495f
renamed config_fake.yml to config.yml
plbenveniste Apr 4, 2024
1df0305
removed useless previous training script train_monai_UNETR.py
plbenveniste Apr 4, 2024
27eb185
script for training unet with finetuning data-aug parameters
plbenveniste Apr 5, 2024
4d1996f
modified training script with new data augmentation strategies
plbenveniste Apr 5, 2024
a28d176
fixed typos in script and arranged in functions
plbenveniste Apr 5, 2024
a9c293f
fixed parameters for model training on entire dataset
plbenveniste Apr 5, 2024
29b8b69
added function to cound lesions and get total volume
plbenveniste Apr 5, 2024
fc18318
changed batch-size to 8
plbenveniste Apr 5, 2024
b1563f8
changed to attentionUnet
plbenveniste Apr 5, 2024
8e8ee70
added lesion only dataset on entirety of images
plbenveniste Apr 5, 2024
d9997a6
added crop foreground for model training
plbenveniste Apr 10, 2024
6f35a4a
added more data augmentation
plbenveniste Apr 10, 2024
a3df668
added precision and recall metric
plbenveniste Apr 10, 2024
1c3c557
added precision and recall metric : function must be reviewed (not su…
plbenveniste Apr 11, 2024
9d80931
modified version of precision/recall metric
plbenveniste Apr 16, 2024
60a63e6
modified config file
plbenveniste Apr 16, 2024
318c117
modified code to include swinUNETR model
plbenveniste Apr 16, 2024
188ee48
fixed wandb and config file for cleaner pipeline
plbenveniste Apr 16, 2024
01bce84
new script to test the dataset
plbenveniste Apr 17, 2024
d96f1e3
config file for testing the dataset
plbenveniste Apr 17, 2024
761d32f
added cupy install for inference
plbenveniste Apr 17, 2024
11abb73
added script for plotting the performance (dice metric) on the data s…
plbenveniste Apr 17, 2024
e4d6088
correct typo in parser
plbenveniste Apr 17, 2024
aa42e94
fixed typo on basel and bavaria data import
plbenveniste Apr 17, 2024
55b0add
changes made for previous run (before ISMRM)
plbenveniste Jun 4, 2024
149fa34
add function to not take files which are in canproco/exclude.yml
plbenveniste Jun 4, 2024
f28bfe4
added lesion wide metrics
plbenveniste Jun 4, 2024
9ac26f3
changed for bavaria dataset new format
plbenveniste Jun 4, 2024
dadc8ac
added function to remove small objects in utils
plbenveniste Jun 4, 2024
3965b9b
added remove small objects for train, val and inference
plbenveniste Jun 4, 2024
71faa52
changed the min volume threshold
plbenveniste Jun 4, 2024
baaa980
changed msd dataset creation for nih and updated bavaria unstiched data
plbenveniste Jun 26, 2024
500de22
updated requirements and added loguru
plbenveniste Jun 26, 2024
8096913
corrected lesion mask name for nih
plbenveniste Jun 26, 2024
3c7e488
corrected requirements
plbenveniste Jun 26, 2024
7cea4c8
config file for training on ETS server
plbenveniste Jul 15, 2024
5fa1ac1
added nnUNet data augmentation
plbenveniste Jul 16, 2024
88b6619
added contrast, site and orientation in msd dataset
plbenveniste Jul 22, 2024
f387067
improved computation of orientation of image
plbenveniste Jul 22, 2024
ec256ab
added __init__.py file for import possibility
plbenveniste Jul 22, 2024
554b522
removed unused files
plbenveniste Jul 22, 2024
1fdbdd3
moved files to utils folder
plbenveniste Jul 22, 2024
8292ff8
updated parameters for model testing
plbenveniste Jul 23, 2024
7e361ba
updated inference script and evaluation plots scripts
plbenveniste Jul 23, 2024
bf07262
added removal of .nii.gz for UINT1 contrast
plbenveniste Jul 23, 2024
9f5587a
changed workers to 0 for test_model
plbenveniste Aug 1, 2024
c23d63f
added more info in output
plbenveniste Aug 1, 2024
034d97e
created file for cropping aroung head
plbenveniste Aug 1, 2024
00ec30c
updated training script to sota model training script (set workers to 0)
plbenveniste Aug 9, 2024
e038c67
changed location of saving of yaml file to save with the same date as…
plbenveniste Aug 9, 2024
391193c
init mednext training script
plbenveniste Aug 30, 2024
25e7874
added library for diffusion model
plbenveniste Sep 4, 2024
b7ee720
first draft (non-functionning) of diffusion model training script
plbenveniste Sep 4, 2024
6f807f9
created script to train a mednext model
plbenveniste Sep 4, 2024
74b6eed
removed cropping of image before inference
plbenveniste Sep 4, 2024
c497c4e
added new config files
plbenveniste Sep 4, 2024
825d816
added script to perform inference and compute the dice score with var…
plbenveniste Sep 4, 2024
ba6d90e
fixed .cpu problem and added more thresholds
plbenveniste Sep 4, 2024
9027f32
first draft of script for TTA
plbenveniste Sep 4, 2024
11eec85
added code to plot the opt threshold output
plbenveniste Sep 5, 2024
38c1594
fixed threshold to 0.5
plbenveniste Sep 5, 2024
e3b1eff
fixed parenthesis when computing dice score
plbenveniste Sep 11, 2024
a37c6a0
added script to compute TTA with 2nd strategy
plbenveniste Sep 11, 2024
60506b4
added script for mednext inference
plbenveniste Sep 12, 2024
a79cb44
added computation of f1-score, ppv and sensitivity
plbenveniste Sep 25, 2024
953de3f
fixed utils command
plbenveniste Oct 31, 2024
15d13b4
removed dataset aggregation scripts
plbenveniste Dec 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions monai/average_tta_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
This file is used to get all the dice_scores_X.txt files in a directory and average them.

Input:
- Path to the directory containing the dice_scores_X.txt files

Output:
None

Example:
python average_tta_performance.py --pred-dir-path /path/to/dice_scores

Author: Pierre-Louis Benveniste
"""

import os
import argparse
import numpy as np
import pandas as pd
from pathlib import Path


def get_parser():
"""
This function returns the parser for the command line arguments.
"""
parser = argparse.ArgumentParser(description="Average the performance of the model")
parser.add_argument("--pred-dir-path", help="Path to the directory containing the dice_scores_X.txt files", required=True)
return parser


def main():
"""
This function is used to average the performance of the model on the test set.

Args:
None

Returns:
None
"""
# Get the parser
parser = get_parser()
args = parser.parse_args()

# Path to the dice_scores
path_to_outputs = args.pred_dir_path

# Get all the dice_scores_X.txt files using rglob
dice_score_files = [str(file) for file in Path(path_to_outputs).rglob("dice_scores_*.txt")]

# Dict to store the dice scores
dice_scores = {}

# Loop over the dice_scores_X.txt files
for dice_score_file in dice_score_files:
# Open dice results (they are txt files)
with open(os.path.join(path_to_outputs, dice_score_file), 'r') as file:
for line in file:
key, value = line.strip().split(':')
if key in dice_scores:
dice_scores[key].append(float(value))
else:
dice_scores[key] = [float(value)]

# Average the dice scores ang get standard deviation
std = {}
for key in dice_scores:
std[key] = np.std(dice_scores[key])
dice_scores[key] = np.mean(dice_scores[key])

# Save the averaged dice scores
with open(os.path.join(path_to_outputs, "dice_scores.txt"), 'w') as file:
for key in dice_scores:
file.write(f"{key}: {dice_scores[key]}\n")

# Save the standard deviation
with open(os.path.join(path_to_outputs, "std.txt"), 'w') as file:
for key in std:
file.write(f"{key}: {std[key]}\n")


if __name__ == "__main__":
main()
130 changes: 130 additions & 0 deletions monai/compute_performance_tta_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
This script is used to sum all the image predictions of the same subject, then threshold to 0.5 and then compute the dice score.

Input:
--path-pred: Path to the directory containing the predictions
--path-json: Path to the json file containing the data split
--split: Data split to use (train, validation, test)
--output-dir: Output directory to save the dice scores

Output:
None

Example:
python compute_performance_tta_sum.py --path-pred /path/to/predictions --path-json /path/to/data.json --split test --output-dir /path/to/output

Author: Pierre-Louis Benveniste
"""

import os
import numpy as np
import argparse
from pathlib import Path
import json
import nibabel as nib
from tqdm import tqdm


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--path-pred", type=str, required=True, help="Path to the directory containing the predictions")
parser.add_argument("--path-json", type=str, required=True, help="Path to the json file containing the data split")
parser.add_argument("--split", type=str, required=True, help="Data split to use (train, validation, test)")
parser.add_argument("--output-dir", type=str, required=True, help="Output directory to save the dice scores")
return parser.parse_args()


def dice_score(prediction, groundtruth, smooth=1.):
numer = (prediction * groundtruth).sum()
denor = (prediction + groundtruth).sum()
# loss = (2 * numer + self.smooth) / (denor + self.smooth)
dice = (2 * numer + smooth) / (denor + smooth)
return dice


def main():

# Parse arguments
args = parse_args()
path_pred = args.path_pred
path_json = args.path_json
split = args.split
output_dir = args.output_dir

# Create the output directory
if not os.path.exists(output_dir):
os.makedirs(output_dir)

# Get all the predictions (with rglob)
predictions = list(Path(path_pred).rglob("*.nii.gz"))

# List of subjects
subjects = [pred.name for pred in predictions]

n_tta = 10

for i in range(n_tta):
# Remove the _pred_0, _pred_1 ... _pred_9 at the end of the name
subjects = [sub.replace(f"_pred_{i}", "") for sub in subjects]

# Open the conversion dictionary (its a json file)
with open(path_json, "r") as f:
conversion_dict = json.load(f)
conversion_dict = conversion_dict[split]

# Dict of dice score
dice_scores = {}

# Iterate over the subjects in the predictions
for subject in subjects:
print(f"Processing subject {subject}")

# Get all predictions corresponding to the subject
subject_predictions = [str(pred) for pred in predictions if subject.replace(".nii.gz", "") in pred.name]
# print(subject_predictions)

# Find the corresponding label from the conversion dict

image_dict = [data for data in conversion_dict if subject in data["image"]]
label = image_dict[0]["label"]
image = image_dict[0]["image"]

# We now sum all the predictions
summed_prediction = None
for pred in subject_predictions:
pred_data = nib.load(pred).get_fdata()
if summed_prediction is None:
summed_prediction = pred_data
else:
summed_prediction += pred_data

# Threshold the summed prediction
summed_prediction[summed_prediction >= 0.5] = 1
summed_prediction[summed_prediction < 0.5] = 0

# Load the label
label_data = nib.load(label).get_fdata()

# Compute dice score
dice = dice_score(summed_prediction, label_data)
# print(f"Dice score for summed prediction: {dice}")

# Compare the dice score with the individual predictions
for pred in subject_predictions:
pred_data = nib.load(pred).get_fdata()
dice_pred = dice_score(pred_data, label_data)
# print(f"Dice score for {pred}: {dice_pred}")

# Save the dice score
dice_scores[image] = dice

# Save the results
with open(os.path.join(output_dir, "dice_scores.txt"), "w") as f:
for key, value in dice_scores.items():
f.write(f"{key}: {value}\n")

return None


if __name__ == "__main__":
main()
50 changes: 50 additions & 0 deletions monai/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Description: Configuration file for the UNETR model

# Path to the data json file
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake.json
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_lesion_sc.json
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_10_each.json
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/fake_sc.json
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-03-13_seed42_canproco.json
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json
# data: /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json
data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json
# data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-08-13_seed42_lesionOnly.json
# data: /home/plbenveniste/net/ms-lesion-agnostic/msd_data/fake.json

# Resampling resolution
# pixdim : [1.0, 1.0, 1.0]
pixdim : [0.7, 0.7, 0.7]
# pixdim : [0.5, 0.5, 0.5]

# Spatial size of the input data
spatial_size : [64, 128, 128] # RL, AP, IS
batch_size : 4 # smaller batch size lead to better generalization https://arxiv.org/abs/1609.04836 but longer to train

# Augmentation parameters
DA_probability : 0.2

# Optimizer parameters
lr : 0.0001
weight_decay: 0.00001
early_stopping_patience : 50

# Training parameters
max_iterations : 250
eval_num : 2

# Outputs
# output_path : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/
output_path : /home/plbenveniste/net/ms-lesion-agnostic/results/
# output_path : /home/plbenveniste/net/ms-lesion-agnostic/results_cropped_head/

# Seed
seed : 42

# UNET model parameters
unet_channels : [32, 64, 128, 256, 512, 1024]
unet_strides : [2, 2, 2, 2, 2, 2, 2]

# AttentionUnet
attention_unet_channels : [32, 64, 128, 256, 512]
attention_unet_strides : [2, 2, 2, 2, 2]
21 changes: 21 additions & 0 deletions monai/config_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-17_seed42_lesionOnly.json
# dataset : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/data/monai_data/dataset_2024-04-05_seed42_lesionOnly.json
dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-06-26_seed42_lesionOnly.json
# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-08-13_seed42_lesionOnly.json
# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_optThresh.json
# dataset : /home/plbenveniste/net/ms-lesion-agnostic/msd_data/fake.json

pixdim : [0.7, 0.7, 0.7]
spatial_size : [64, 128, 128]
attention_unet_channels : [32, 64, 128, 256, 512]
attention_unet_strides : [2, 2, 2, 2, 2]

# path_to_model : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/best_model.pth/best_model.ckpt
# path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp/best_model.pth/best_model.ckpt
path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/best_model.pth/best_model.ckpt
# path_to_model : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-09-02_12:14:28.124188/best_model.pth/best_model.ckpt

# output_dir : /home/GRAMES.POLYMTL.CA/p119007/ms_lesion_agnostic/results/2024-04-21_16:06:04.890513/
# output_dir : /home/plbenveniste/net/ms-lesion-agnostic/tta_exp
output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-07-18_10:46:21.634514/
# output_dir : /home/plbenveniste/net/ms-lesion-agnostic/results/2024-09-02_12:14:28.124188/
85 changes: 85 additions & 0 deletions monai/plot_optThresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
This script plots the performance of the model based on the threshold applied to the predictions.

Input:
--path-scores: Path to the directory containing the dice_scores_X.txt files

Output:
None

Example:
python plot_optThresh.py --path-scores /path/to/dice_scores

Author: Pierre-Louis Benveniste
"""

import os
import argparse
import numpy as np
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def get_parser():
"""
This function returns the parser for the command line arguments.
"""
parser = argparse.ArgumentParser(description="Plot the optimal threshold")
parser.add_argument("--path-scores", help="Path to the directory containing the dice_scores_X.txt files", required=True)
return parser


def main():

# Get the parser
parser = get_parser()
args = parser.parse_args()

# Path to the dice_scores
path_to_outputs = args.path_scores

# Get all the dice_scores_X.txt files using rglob
dice_score_files = [str(file) for file in Path(path_to_outputs).rglob("dice_scores_*.txt")]

# Create a list to store the dataframes
test_dice_results_list = [None] * len(dice_score_files)

# For each file, get the threshold and the dice score
for i, dice_score_file in enumerate(dice_score_files):
test_dice_results = {}
with open(dice_score_file, 'r') as file:
for line in file:
key, value = line.strip().split(':')
test_dice_results[key] = float(value)
# convert to a df with name and dice score
test_dice_results_list[i] = pd.DataFrame(list(test_dice_results.items()), columns=['name', 'dice_score'])
# Create a column which stores the threshold
test_dice_results_list[i]['threshold'] = str(Path(dice_score_file).name).replace('dice_scores_', '').replace('.txt', '').replace('_', '.')

# Concatenate all the dataframes
test_dice_results = pd.concat(test_dice_results_list)

# Plot
plt.figure(figsize=(20, 10))
plt.grid(True)
sns.violinplot(x='threshold', y='dice_score', data=test_dice_results)
# y ranges from -0.2 to 1.2
plt.ylim(-0.2, 1.2)
plt.title('Dice scores per threshold')
plt.show()

# Save the plot
plt.savefig(path_to_outputs + '/dice_scores_contrast.png')
print(f"Saved the dice_scores plot in {path_to_outputs}")

# Print the average dice score per threshold
for thresh in test_dice_results['threshold'].unique():
print(f"Threshold: {thresh} - Average dice score: {test_dice_results[test_dice_results['threshold'] == thresh]['dice_score'].mean()}")

return None


if __name__ == "__main__":
main()
Loading