Pre-training on large-scale T2w healthy/pathology data #11

Draft · wants to merge 274 commits into main

Changes from 205 of 274 commits

Commits
f5881aa
change 'num_workers' to 0 to prevent 'RuntimeError: received 0 items …
valosekj Mar 12, 2024
7aa9858
Update input arg description
valosekj Mar 13, 2024
16691d7
Rerun the notebook
valosekj Mar 13, 2024
56ca187
Print hyper-parameters into the log file
valosekj Mar 13, 2024
ce7315a
Add '--cuda' input arg
valosekj Mar 13, 2024
1ebf833
track and save epoch time
valosekj Mar 15, 2024
208b2c1
Plot and save input and output validation images to see how the model…
valosekj Mar 15, 2024
5a9e69a
Add comment
valosekj Mar 15, 2024
eca6616
Do not plot 'outputs_v2' as it is a hidden representation
valosekj Mar 15, 2024
8e55985
Include the epoch number as master title
valosekj Mar 15, 2024
e7575c5
Add 'torch.multiprocessing.set_sharing_strategy('file_system')' to so…
valosekj Mar 15, 2024
3b6df14
Create validation_figures directory if it does not exist
valosekj Mar 15, 2024
59b061c
Use 3 leading zeros for the epoch number in the figures fname
valosekj Mar 15, 2024
f2d648b
Link issue
valosekj Mar 15, 2024
a6bc6ec
Add note for 'RandCropByPosNegLabeld'
valosekj Mar 15, 2024
c1553d6
Add 'number_of_holes' arg to specify the number of holes to be used f…
valosekj Mar 15, 2024
0e802df
typo
valosekj Mar 16, 2024
9255855
batch_size = 8
valosekj Mar 16, 2024
8296913
NUM_WORKERS = batch_size
valosekj Mar 16, 2024
93c5415
number_of_holes=5
valosekj Mar 16, 2024
a4182b5
Update transforms for training of the fine-tuned model
valosekj Mar 16, 2024
2573578
update comment, remove unused imports
valosekj Mar 16, 2024
cfb8f12
Add notebook with RandCoarseDropoutd transform debug
valosekj Mar 16, 2024
7f27340
Fix 'dropout_holes=True' and 'dropout_holes=False' comments
valosekj Mar 17, 2024
74b2f2c
remove unused 'max_spatial_size' arg
valosekj Mar 17, 2024
bdb2ac1
use 'fill_value=0' for 'RandCoarseDropoutd'
valosekj Mar 17, 2024
f7f3689
Plot also RandCoarseDropoutd dropout_holes=False fill_value=0
valosekj Mar 17, 2024
af9c6d0
Add note that the batch size is actually doubled (8*2=16), because we…
valosekj Mar 17, 2024
b416135
Add '--cuda' input arg
valosekj Mar 18, 2024
0303776
Remove 'AsDiscrete'
valosekj Mar 18, 2024
4534c57
Remove 'AsDiscrete'
valosekj Mar 18, 2024
94c8611
Add 'CUDA_NUM=args.cuda'
valosekj Mar 18, 2024
5fb325a
Add TODO to increase batch_size to 16
valosekj Mar 18, 2024
d1a03c8
Use 'roi_size' for 'define_finetune_train_transforms'
valosekj Mar 18, 2024
e3d8086
Use 'label_sc' to crop samples around the SC
valosekj Mar 18, 2024
7e628af
batch_size = 8
valosekj Mar 18, 2024
f832c94
NUM_WORKERS = batch_size
valosekj Mar 18, 2024
c68d65d
Add 'import torch.multiprocessing'
valosekj Mar 18, 2024
a7c7c77
Fix shape logging
valosekj Mar 18, 2024
1c6109c
Change 'img_size' to 'ROI_SIZE'
valosekj Mar 18, 2024
b92e7d5
'batch["label"]' --> 'batch["label_lesion"]'
valosekj Mar 18, 2024
f174fb9
Plot and save input and output validation images to see how the model…
valosekj Mar 18, 2024
532e24b
Fix ROI_SIZE for sliding_window_inference
valosekj Mar 18, 2024
5c10890
Crop samples of 64x64x64 also for Validation of the fine-tuned model
valosekj Mar 18, 2024
d33abfa
update docstring
valosekj Mar 18, 2024
efeccf9
log validation samples shapes
valosekj Mar 18, 2024
f34ece1
Save validation figure only if it contains a lesion
valosekj Mar 18, 2024
e0afafe
Plot GT together with image
valosekj Mar 19, 2024
e588a22
print unique values in the slice to see if it is binary
valosekj Mar 19, 2024
2d0b530
update output fig fname
valosekj Mar 19, 2024
988680e
Add 'AsDiscreted' for Training and Validation of the fine-tuned model
valosekj Mar 19, 2024
f675611
threshold val_labels_list and val_outputs_list by 0.5 threshold befor…
valosekj Mar 19, 2024
cc78896
add normalized relu normalization
naga-karthik Mar 19, 2024
92334ea
fix binarization bug
naga-karthik Mar 19, 2024
2c49ca2
remove 'logger.info(np.unique(output.detach().cpu().numpy()))'
valosekj Mar 19, 2024
fbe3dd7
overlay prediction over input image
valosekj Mar 19, 2024
ba39d28
Fix variable when getting probabilities from logits
valosekj Mar 20, 2024
c8d9a5d
Add debug lines
valosekj Mar 20, 2024
ac040a5
PEP8
valosekj Mar 20, 2024
1e2b61d
Set validation batch_size to 1
valosekj Mar 20, 2024
8f49545
improve comments
valosekj Mar 20, 2024
3a74e28
fix figure title
valosekj Mar 20, 2024
dea121c
comment 'AsDiscreted' transforms
valosekj Mar 20, 2024
013ca65
Make '--pretrained-model' non required to allow training from the scr…
valosekj Mar 21, 2024
f1aae1b
Add script to create spine-generic MSD dataset
valosekj Mar 22, 2024
e4c87be
run notebook again
valosekj Mar 22, 2024
eff9fe8
add loss script containing loss functions
naga-karthik Mar 24, 2024
2eb3ab7
add lr scheduler
naga-karthik Mar 24, 2024
340e824
add init version of transforms
naga-karthik Mar 24, 2024
a3f6f52
add script for dataloading
naga-karthik Mar 24, 2024
6c764b6
add utils for defining models
naga-karthik Mar 24, 2024
9e7b8a1
working init version of model backbones
naga-karthik Mar 24, 2024
ae053ea
untested version of supervised pretraining code
naga-karthik Mar 24, 2024
c26da42
fix Spacingd mode for pretraining
naga-karthik Mar 24, 2024
eadfc47
add utils file
naga-karthik Mar 25, 2024
8c3d33a
fix typo in pretraining transforms
naga-karthik Mar 25, 2024
584aedd
save output to a log file
valosekj Mar 26, 2024
cf95ec2
Add new input arg "--datalists" allowing to specify JSON datalist(s) …
valosekj Mar 26, 2024
ab57eda
Add loop across datalists
valosekj Mar 26, 2024
a9f65cc
Add docstring to load_data.loader
valosekj Mar 26, 2024
cf2076d
Various minor arg parse fixes
valosekj Mar 26, 2024
d9e0f44
PEP8 formatting
valosekj Mar 26, 2024
3037a5f
Add comment to @torch.no_grad()
valosekj Mar 26, 2024
f2af24f
Add logger.info that config is loaded
valosekj Mar 26, 2024
1ac7156
rename "type" to "mode" to avoid conflict with built-in name
valosekj Mar 27, 2024
ddca389
PEP8 formatting
valosekj Mar 27, 2024
478405d
rename 'type' --> 'mode'
valosekj Mar 27, 2024
b62c388
swap the order of val_transforms and inference_transforms for easier …
valosekj Mar 27, 2024
49115b6
remove ununsed imports
valosekj Mar 27, 2024
4287184
Okay, one more renaming :-D
valosekj Mar 27, 2024
0016a38
update loader.load_data docstring
valosekj Mar 27, 2024
305a380
Fix transforms.transforms typo
valosekj Mar 27, 2024
de50aae
Fix 'model_utils' import to 'models.model_utils'
valosekj Mar 27, 2024
d207934
Fix 'evaluate' keyword arg name
valosekj Mar 27, 2024
a362ed7
Simplify 'config["autoencoderkl"]["seed"]' --> 'config["seed"]'
valosekj Mar 27, 2024
41653b0
Change use_distributed to False to prevent:
valosekj Mar 27, 2024
fbe6061
Use 'model.run_folder' instead of 'run_folder'
valosekj Mar 27, 2024
63d9226
Comment 'wandb_run'
valosekj Mar 27, 2024
248ef73
fix import typo
valosekj Mar 27, 2024
18b1d50
Add 'logger.info(f"Using device: {device}")'
valosekj Mar 27, 2024
f16d8a3
rename 'x' to 'batch_data' to make code easier to follow
valosekj Mar 27, 2024
fc1e07b
remove unused imports
valosekj Mar 27, 2024
606bc84
Add an example of the config YAML file using 'SmartFormatter'
valosekj Mar 27, 2024
995c71b
Comment depths, feature_size, and num_heads for the SwinUNETR model t…
valosekj Mar 27, 2024
7dd6711
fix merge conflicts
naga-karthik Mar 27, 2024
473ade4
fix merge conflicts
naga-karthik Mar 27, 2024
4932438
remove arg in EnsureTyped to fix issue with multi-gpu training
naga-karthik Mar 27, 2024
94e7658
add dice_score()
naga-karthik Mar 27, 2024
6828e8d
uncomment swinunetr params; improve logging info
naga-karthik Mar 27, 2024
efab1cd
remove unetr as model choice (not supported)
naga-karthik Mar 27, 2024
1a6ad29
remove local_rank as arg in favor of os.environ
naga-karthik Mar 27, 2024
f1c3c33
use sliding_window_inference for validation
naga-karthik Mar 27, 2024
72e5b21
unwrap model if DDP is used
naga-karthik Mar 27, 2024
e169037
log only if local_rank is 0
naga-karthik Mar 27, 2024
fa85dd5
fix CUDA error by 'spawn'
naga-karthik Mar 27, 2024
f9511fb
move everything to local_rank
naga-karthik Mar 27, 2024
619b03a
create run folders
naga-karthik Mar 27, 2024
8ae5470
minor code improvements
naga-karthik Mar 27, 2024
fdc8474
add usage example
naga-karthik Mar 27, 2024
93047c5
add adaptive wing loss for SC seg pre-training
naga-karthik Mar 27, 2024
e3ba28a
switch to AdapWingLoss
naga-karthik Mar 27, 2024
6c3a350
add example yaml config fille
naga-karthik Mar 28, 2024
eb70621
Remove 'directories' and 'dataset' keys
valosekj Mar 28, 2024
8cf1910
Remove 'spacing' key because resampling to 1.0x1.0x1.0 spacing is har…
valosekj Mar 28, 2024
f0bc514
Add 'train_batch_size' and 'val_batch_size' keys
valosekj Mar 28, 2024
bcd0df1
Add note for swinunetr that spatial dimensions of input image must be…
valosekj Mar 28, 2024
0390e24
Add depths, num_heads, and feature_size for swinunetr to the example …
valosekj Mar 28, 2024
ff2fa0d
Update optimizer in the example config
valosekj Mar 28, 2024
c8c4f70
Remove train_batch_size and val_batch_size -- they are not used. Inst…
valosekj Mar 28, 2024
92763e7
Add missing space to logger.info
valosekj Mar 28, 2024
0813889
rename pretraining/ to pretraining_and_finetuning/
naga-karthik Mar 28, 2024
c3508ff
rename main_supervised.py to main_supervised_pretraining.py
naga-karthik Mar 28, 2024
ad4b891
Merge branch 'nk/jv_vit_unetr_ssl' of https://github.com/ivadomed/mod…
naga-karthik Mar 28, 2024
86fb668
Plot and save input and output validation images to see how the model…
valosekj Mar 28, 2024
56f05de
Merge branch 'nk/jv_vit_unetr_ssl' of github.com:ivadomed/model-seg-d…
valosekj Mar 28, 2024
6d91c59
Add script to create MSD-style JSON datalist file for BIDS datasets.
valosekj Mar 28, 2024
0d19c50
Add the `whole-spine` dataset
valosekj Mar 28, 2024
460c227
Add note for canproco that we do not use sub-cal* subjects right now …
valosekj Mar 28, 2024
43b1687
Add note for whole-spine that we do not use this dataset right now du…
valosekj Mar 28, 2024
8e204bb
add function to plot slices
naga-karthik Mar 28, 2024
ef2b1f3
move plotting to plot_slices() in utils.py
naga-karthik Mar 28, 2024
730d146
crop around SC; resample to 0.8 iso; re-use transforms in validation
naga-karthik Mar 28, 2024
ab8b3cc
modify crop size -- leave 3rd dim. open
naga-karthik Mar 28, 2024
0d4b484
PEP8 formatting
valosekj Mar 28, 2024
605c414
Plot 10 equally spaced slices
valosekj Mar 28, 2024
b93837a
Plot all validation samples, not only the first one
valosekj Mar 28, 2024
b3227b1
explicitly mention that the script creates datasets for the SUPERVISE…
valosekj Mar 28, 2024
c0fe9f8
fix title placing
valosekj Mar 28, 2024
12ab8a0
Use alpha=0.5
valosekj Mar 28, 2024
448077b
Add slice number below each prediction slice
valosekj Mar 28, 2024
d1f2549
Reduce space between subplots
valosekj Mar 28, 2024
77b9e4e
Move slice number down
valosekj Mar 28, 2024
a2ee61e
Move master title down
valosekj Mar 28, 2024
fdb3ed7
Use dpi=200 to save the val figures
valosekj Mar 28, 2024
6f35733
change crop_pad_size from [64, 64, -1] to [64, 64, 64] to avoid the f…
valosekj Mar 28, 2024
148c1eb
fix bug swinunetr img_size arg -- set it to patch_size
naga-karthik Mar 29, 2024
0689fe4
PEP8 formatting
valosekj Mar 29, 2024
34a1cd7
PEP8 formatting
valosekj Mar 29, 2024
2a53225
Rename 'InitWeights_He' --> 'InitWeightsHe' to supress warning that a…
valosekj Mar 29, 2024
5287221
Logging refactoring: log everything into a log inside model.run_folder
valosekj Mar 29, 2024
6329c9e
PEP8
valosekj Mar 29, 2024
a498d25
move 'set_determinism(config["seed"])' up
valosekj Mar 29, 2024
e8b64f5
Move initializing of the distributed training process before DDP
valosekj Mar 29, 2024
7e595b4
Read log_dir from the model.log_dir before DDP
valosekj Mar 29, 2024
478b26d
Add 'logger.info("Getting datalists ...")'
valosekj Mar 29, 2024
6924ecd
Plot only a single random validation image each epoch to limit number…
valosekj Mar 29, 2024
642cd2c
Remove unused imports
valosekj Mar 29, 2024
6d91c44
Fix the logic to get random val sample to plot
valosekj Mar 29, 2024
e0a5259
Remove space
valosekj Mar 29, 2024
fed39b3
remove redundant if args.dist statements
naga-karthik Mar 29, 2024
39558c8
Add note why we use "if local_rank == 0 else None"
valosekj Mar 29, 2024
5d11d46
Add docstring for 'plot_slices'
valosekj Mar 29, 2024
09ac128
set 3rd dim to -1 when cropping/resizing
naga-karthik Mar 30, 2024
35cd1d1
add wip version of finetuning script
naga-karthik Mar 30, 2024
53e45d3
update finetuning transforms
naga-karthik Mar 30, 2024
2ef78ba
fix bug in replacing prefixess state_dict's keys when loading pretrai…
naga-karthik Mar 30, 2024
fb01dac
log train/val dice score to the terminal along with loss
naga-karthik Mar 30, 2024
d086d58
add RandCropByPosNegLabeld transform to finetuning
naga-karthik Mar 30, 2024
4afbf83
log train/val dice score to terminal in finetuning
naga-karthik Mar 30, 2024
c8ff911
improve logging based on finetuning or from scratch
naga-karthik Mar 30, 2024
8a7441d
fix bug in model loading
naga-karthik Mar 30, 2024
5cc9eb4
Add notes
valosekj Mar 30, 2024
eaa73bc
remove unused imports
valosekj Mar 30, 2024
bcf0025
improve 'raise ValueError' and 'logger.info' messages
valosekj Mar 30, 2024
70d4b3b
remove spaces before '...' during logging
valosekj Mar 30, 2024
a4f10b2
Add 'self.task' to the class 'BackboneModel' to track whether we do p…
valosekj Mar 30, 2024
07c9533
Use 'self.model_name' instead of 'model_name'
valosekj Mar 30, 2024
50d93a1
Use 'self.model_name' instead of 'model_name'
valosekj Mar 30, 2024
9198cb5
Add 'logger.info(f'Output directory: {self.log_dir}')'
valosekj Mar 30, 2024
321d858
Remove unused 'device' param from 'train_transforms'
valosekj Mar 30, 2024
0490041
Use """ instead of ''' for docstring
valosekj Mar 30, 2024
3fd0b8b
add epochs to run folder name
naga-karthik Apr 2, 2024
167b78a
Add 'task' input arg (choices: pretraining or finetuning) to the 'plo…
valosekj Apr 2, 2024
a80fbac
Add TODO that we may want to plot sagittal slices for sagittal images
valosekj Apr 2, 2024
5b231f3
If the label is empty, exit the function
valosekj Apr 2, 2024
76fbedf
Add logging
valosekj Apr 2, 2024
e588c46
Add 'num_slices_to_plot = len(slice_indices)'
valosekj Apr 2, 2024
a7d417d
Fix plotting logic
valosekj Apr 2, 2024
c579bf5
Improve logging message
valosekj Apr 2, 2024
2746ed0
Copy the config file to the log directory to keep track of the traini…
valosekj Apr 4, 2024
e14ef1c
Log the total number of the training and validation samples
valosekj Apr 4, 2024
285c00d
PEP8
valosekj Apr 4, 2024
6ffa6b2
PEP8
valosekj Apr 4, 2024
56de837
Add TODO to limit multi-processing warnings
valosekj Apr 4, 2024
4561f73
Plot training and validation loss and dice values
valosekj Apr 4, 2024
a4f3385
Add comments, clarify variable name
valosekj Apr 4, 2024
6821df0
Use '.detach().cpu()'
valosekj Apr 4, 2024
13b85c6
Make fontsize smaller
valosekj Apr 4, 2024
1e9e5b7
Plot training and validation loss also for finetuning
valosekj Apr 4, 2024
61a22ec
Make "--resume-from-checkpoint" and "--run-dir" input args properly w…
valosekj Apr 5, 2024
bb7cd2c
Add logging when the new best model saved
valosekj Apr 5, 2024
cbcf7ea
Comment 'transforms.CropForegroundd'
valosekj Apr 6, 2024
b95eb1f
Change crop_pad_size to [64, 160, 320] to not crop SC
valosekj Apr 6, 2024
3c219aa
Save the latest checkpoint.pth after each epoch
valosekj Apr 6, 2024
7038238
Remove logging validation loss when saving checkpoint.pth
valosekj Apr 6, 2024
f6f1aae
Add sci-zurich
valosekj Apr 7, 2024
94fdabf
add training params for monai-unet
naga-karthik Apr 8, 2024
14e7606
change max_epochs; update batch_size
naga-karthik Apr 8, 2024
d8e1dc3
uncomment affine transform
naga-karthik Apr 8, 2024
cb89143
use DiceCELoss instead of AdapWing
naga-karthik Apr 8, 2024
6be6e64
move normalization up before calculating loss
naga-karthik Apr 8, 2024
aac5bb0
Fix mode for transforms.RandAffined
valosekj Apr 8, 2024
2d9a381
uncomment transforms.CropForegroundd
valosekj Apr 8, 2024
5540fa4
Save 'checkpoint.pth' after each epoch
valosekj Apr 8, 2024
783117a
Make "--resume-from-checkpoint" and "--run-dir" input args properly w…
valosekj Apr 8, 2024
04dfeb8
limit the number of slices to plot to 10
valosekj Apr 8, 2024
fadc058
Change transforms.Spacingd to 1.0 iso
valosekj Apr 8, 2024
3f23f35
comment transforms.CropForegroundd. Again :-D
valosekj Apr 8, 2024
ce81111
uncomment transforms
valosekj Apr 8, 2024
3830cb0
change crop_pad_size and epoch nb
valosekj Apr 8, 2024
85d1785
do not hardcode 'config["model"]["monai-unet"]["spatial_dims"]'
valosekj Apr 8, 2024
4ff571e
fix plotting logic
valosekj Apr 8, 2024
ff8673e
Run 'plot_slices' only if local_rank == 0
valosekj Apr 9, 2024
d3ab5ce
Add 'local_rank' to BackboneModel to do not duplicate logger.info
valosekj Apr 9, 2024
61fdf48
Fix 'local_rank' passing to 'load_data' to do not duplicate logger.info
valosekj Apr 9, 2024
7f4538e
Use 'local_rank == 0' to not duplicate logger.info
valosekj Apr 9, 2024
e4803bb
Copy the transforms.py file to the log directory to keep track of the…
valosekj Apr 10, 2024
6c0fb6d
Use 'if local_rank == 0 else None' and add 'logger.info'
valosekj Apr 10, 2024
3648d91
Rotate val samples by 90 degrees before plotting
valosekj Apr 10, 2024
4b1234d
logg loss function
valosekj Apr 10, 2024
fe3fb74
Add dropout to 'monai-unet'
valosekj Apr 10, 2024
4800530
Add a comment why we use 320 channels
valosekj Apr 10, 2024
61f2908
Copy datalists to the log directory
valosekj Apr 11, 2024
d5be144
Fix typo
valosekj Apr 11, 2024
e2e4c07
Add 'git annex get $(find . -name "*T2w_lesion-manual.nii.gz")'
valosekj Apr 11, 2024
0d8aac8
Rotate by 270 degrees instead of 90
valosekj Apr 11, 2024
d9da21a
Add attentionunet
valosekj Apr 11, 2024
4f8415f
Add attentionunet
valosekj Apr 12, 2024
4d429fc
Fix suptitle position
valosekj Apr 12, 2024
394c3ef
remove num_res_units for AttentionUnet
valosekj Apr 12, 2024
50 changes: 50 additions & 0 deletions configs/train_supervised.yaml
@@ -0,0 +1,50 @@
seed: 30
save_test_preds: True

preprocessing:
  # Center crop/pad images to the specified size. (NOTE: done after resampling to 1.0x1.0x1.0)
  # Values correspond to the R-L, A-P, I-S axes of the image after 1 mm isotropic resampling.
  # NOTE for swinunetr: spatial dimensions of the input image must be divisible by 32, i.e., we cannot use, for example, 48x48x48
  crop_pad_size: [64, 64, -1]
  patch_size: [64, 64, 64]

opt:
  name: adam
  lr: 0.001
  max_epochs: 200
  warmup_epochs: 20
  batch_size: 2
  # Interval between validation checks in epochs
  check_val_every_n_epochs: 5
  # Early stopping patience (effective patience is early_stopping_patience * check_val_every_n_epochs epochs)
  early_stopping_patience: 20


model:
  # Model architecture to be used for training (also to be specified as args in the command line)
  nnunet:
    in_channels: 1
    out_channels: 1
    # NOTE: these values are typically taken from nnUNetPlans.json (if an nnUNet model is trained)
    base_num_features: 32
    max_num_features: 320
    n_conv_per_stage_encoder: [2, 2, 2, 2, 2, 2]
    n_conv_per_stage_decoder: [2, 2, 2, 2, 2]
    pool_op_kernel_sizes: [
      [1, 1, 1],
      [2, 2, 2],
      [2, 2, 2],
      [2, 2, 2],
      [2, 2, 2],
      [1, 2, 2]
    ]
    enable_deep_supervision: True

  swinunetr:
    in_channels: 1
    out_channels: 1
    spatial_dims: 3
    depths: [2, 2, 2, 2]
    num_heads: [3, 6, 12, 24]  # number of heads in multi-head attention
    feature_size: 36
    use_pretrained: False
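For reference, a config like this can be read with a standard YAML load; a minimal sketch (assuming PyYAML is available; the actual loading code in the training entry point may differ):

```python
import yaml  # PyYAML; assumed available in the training environment

# Load the training config (path taken from this PR; adjust to your checkout)
with open("configs/train_supervised.yaml", "r") as f:
    config = yaml.safe_load(f)

# Nested keys become plain dicts/lists after loading, e.g.:
patch_size = config["preprocessing"]["patch_size"]           # [64, 64, 64]
lr = config["opt"]["lr"]                                     # 0.001
feature_size = config["model"]["swinunetr"]["feature_size"]  # 36

# swinunetr constraint noted in the config: every spatial dimension of the
# input patch must be divisible by 32 (so 64x64x64 works, 48x48x48 does not)
assert all(dim % 32 == 0 for dim in patch_size)
```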
85 changes: 84 additions & 1 deletion dataset-conversion/README.md
@@ -1 +1,84 @@
This folder contains various scripts used for dataset conversion from BIDS to nnUNet/MONAI formats.
# Pre-training on multiple datasets

### Download datasets

Download T2w images and spinal cord segmentations for the following datasets. All datasets are cloned into a common working directory:

```commandline
cd ~/duke/temp/janvalosek/ssl_pretraining_multiple_datasets
```

`spine-generic multi-subject` (n=267)

```commandline
git clone https://github.com/spine-generic/data-multi-subject
cd data-multi-subject
git checkout sb/156-add-preprocessed-images
git annex get $(find . -name "*space-other_T2w.nii.gz")
git annex get $(find . -name "*space-other_T2w_label-SC_seg.nii.gz")
```

`whole-spine` (n=45)

NOTE: we do not use this dataset right now due to https://github.com/neuropoly/data-management/issues/306

```commandline
git clone [email protected]:datasets/whole-spine
cd whole-spine
git annex dead here
git annex get $(find . -name "*T2w.nii.gz")
git annex get $(find . -name "*T2w_seg.nii.gz")
```

`canproco` (n=321)

NOTE: we do not use sub-cal* subjects right now due to https://github.com/neuropoly/data-management/issues/305

```commandline
git clone [email protected]:datasets/canproco
cd canproco
git annex dead here
git annex get $(find . -name "*ses-M0_T2w.nii.gz")
git annex get $(find . -name "*ses-M0_T2w_seg-manual.nii.gz")
```

`sci-colorado` (n=80)

```commandline
git clone [email protected]:datasets/sci-colorado
cd sci-colorado
git annex dead here
git annex get $(find . -name "*T2w.nii.gz")
git annex get $(find . -name "*T2w_seg-manual.nii.gz")
```

`dcm-zurich` (n=135)

```commandline
git clone [email protected]:datasets/dcm-zurich
cd dcm-zurich
git annex dead here
git annex get $(find . -name "*acq-axial_T2w.nii.gz")
git annex get $(find . -name "*acq-axial_T2w_label-SC_mask-manual.nii.gz")
```

`sci-paris` (n=14)

```commandline
git clone [email protected]:datasets/sci-paris
cd sci-paris
git annex dead here
git annex get $(find . -name "*T2w.nii.gz")
git annex get $(find . -name "*T2w_seg.nii.gz")
```

### Create MSD-style JSON datalists

```commandline
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data data-multi-subject --dataset-name spine-generic --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data whole-spine --dataset-name whole-spine --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data canproco --dataset-name canproco --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data sci-colorado --dataset-name sci-colorado --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data dcm-zurich --dataset-name dcm-zurich --path-out . --split 0.8 0.2 --seed 42
python /Users/user/code/model-seg-dcm/vit_unetr_ssl/create_msd_data.py --path-data sci-paris --dataset-name sci-paris --path-out . --split 0.8 0.2 --seed 42
```
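For orientation, the datalist produced by `create_msd_data.py` (added below) has the following shape; the subject IDs, paths, and sample counts here are illustrative only:

```json
{
    "dataset_name": "spine-generic",
    "contrast": "space-other_T2w",
    "labels": {"0": "background", "1": "sc-seg"},
    "modality": {"0": "MRI"},
    "numTraining": 213,
    "numValidation": 54,
    "seed": 42,
    "tensorImageSize": "3D",
    "training": [
        {
            "image": "/path/to/data-multi-subject/derivatives/data_preprocessed/sub-amu01/anat/sub-amu01_space-other_T2w.nii.gz",
            "label_sc": "/path/to/data-multi-subject/derivatives/labels/sub-amu01/anat/sub-amu01_space-other_T2w_label-SC_seg.nii.gz"
        }
    ],
    "validation": [
        {
            "image": "/path/to/data-multi-subject/derivatives/data_preprocessed/sub-amu02/anat/sub-amu02_space-other_T2w.nii.gz",
            "label_sc": "/path/to/data-multi-subject/derivatives/labels/sub-amu02/anat/sub-amu02_space-other_T2w_label-SC_seg.nii.gz"
        }
    ]
}
```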
184 changes: 184 additions & 0 deletions dataset-conversion/create_msd_data.py
@@ -0,0 +1,184 @@
"""
Create MSD-style JSON datalist file for BIDS datasets.
The following two keys are included in the JSON file: 'image' and 'label_sc'.

NOTE: the script is meant to be used for SUPERVISED pre-training, meaning that the dataset is split into training and
validation.
In other words, NO testing set is created.

The script has to be run for each dataset separately, meaning that one JSON file is created for each dataset.

Example usage:
python create_msd_data.py
--path-data /Users/user/data/spine-generic
--dataset-name spine-generic
--path-out /Users/user/data/spine-generic

python create_msd_data.py
--path-data /Users/user/data/dcm-zurich
--dataset-name dcm-zurich
--path-out /Users/user/data/dcm-zurich
"""

import os
import json
import argparse
from pathlib import Path
from loguru import logger
from sklearn.model_selection import train_test_split

contrast_dict = {
    'spine-generic': 'space-other_T2w',  # iso T2w (preprocessed data)
    'whole-spine': 'T2w',                # iso T2w
    'canproco': 'ses-M0_T2w',            # iso T2w (session M0)
    'dcm-zurich': 'acq-axial_T2w',       # axial T2w
    'sci-paris': 'T2w',                  # iso T2w
    'sci-colorado': 'T2w'                # axial T2w
}

# Spinal cord segmentation file suffixes for different datasets
sc_fname_suffix_dict = {
    'spine-generic': 'label-SC_seg',
    'whole-spine': 'seg',
    'canproco': 'seg-manual',
    'dcm-zurich': 'label-SC_mask-manual',
    'sci-paris': 'seg-manual',
    'sci-colorado': 'seg-manual'
}


def get_parser():
    parser = argparse.ArgumentParser(description='Create MSD-style JSON datalist file for BIDS datasets.')

    parser.add_argument('--path-data', required=True, type=str,
                        help='Path to BIDS dataset. Example: /Users/user/data/dcm-zurich')
    parser.add_argument('--dataset-name', required=True, type=str,
                        help='Name of the dataset. Example: spine-generic or dcm-zurich.')
    parser.add_argument('--path-out', type=str, required=True,
                        help='Path to the output directory where the dataset JSON is saved')
    parser.add_argument('--split', nargs='+', type=float, default=[0.8, 0.2],
                        help='Ratios of training and validation 0-1. '
                             'Example: --split 0.8 0.2')
    parser.add_argument('--seed', default=42, type=int, help="Seed for reproducibility")

    return parser


def main():
    args = get_parser().parse_args()

    dataset = os.path.abspath(args.path_data)
    dataset_name = args.dataset_name
    train_ratio, val_ratio = args.split
    seed = args.seed
    path_out = os.path.abspath(args.path_out)

    # Check if the dataset name is valid
    if dataset_name not in contrast_dict.keys():
        raise ValueError(f"Dataset name {dataset_name} is not valid. Choose from {contrast_dict.keys()}")

    contrast = contrast_dict[dataset_name]
    sc_fname_suffix = sc_fname_suffix_dict[dataset_name]
    datalist_fname = f"{dataset_name}_seed{seed}"

    train_images, val_images = {}, {}

    # For spine-generic, we add 'derivatives/data_preprocessed' to the path to use the preprocessed data with the same
    # resolution and orientation as the spinal cord segmentations
    if dataset_name == 'spine-generic':
        root = Path(dataset) / 'derivatives/data_preprocessed'
    else:
        root = Path(dataset)
    # Path to 'derivatives/labels' with spinal cord segmentations
    labels = Path(dataset) / 'derivatives/labels'

    # Check if the dataset paths exist
    if not os.path.exists(root):
        raise ValueError(f"Path {root} does not exist.")
    if not os.path.exists(labels):
        raise ValueError(f"Path {labels} does not exist.")

    logger.info(f"Root path: {root}")
    logger.info(f"Labels path: {labels}")

    # Get all the subjects from the root folder
    subjects = [sub for sub in os.listdir(root) if sub.startswith("sub-")]

    # Get the training and validation splits
    # Note: the script is used for supervised pre-training, so no test set is needed
    tr_subs, val_subs = train_test_split(subjects, test_size=val_ratio, random_state=args.seed)

    # Recursively find the spinal cord segmentation files under 'derivatives/labels' for training and validation
    # subjects
    tr_seg_files = [str(path) for sub in tr_subs for path in
                    Path(labels).rglob(f"{sub}_{contrast}_{sc_fname_suffix}.nii.gz")]
    val_seg_files = [str(path) for sub in val_subs for path in
                     Path(labels).rglob(f"{sub}_{contrast}_{sc_fname_suffix}.nii.gz")]

    # Update the train and validation images dicts with the key as the subject and value as the path to the subject
    train_images.update({sub: os.path.join(root, sub) for sub in tr_seg_files})
    val_images.update({sub: os.path.join(root, sub) for sub in val_seg_files})

    logger.info(f"Found subjects in the training set: {len(train_images)}")
    logger.info(f"Found subjects in the validation set: {len(val_images)}")

    # Keys to be defined in the output datalist JSON
    params = {}
    params["dataset_name"] = dataset_name
    params["contrast"] = contrast
    params["labels"] = {
        "0": "background",
        "1": "sc-seg"
    }
    params["modality"] = {
        "0": "MRI"
    }
    params["numTraining"] = len(train_images)
    params["numValidation"] = len(val_images)
    params["seed"] = args.seed
    params["tensorImageSize"] = "3D"

    train_images_dict = {"training": train_images}
    val_images_dict = {"validation": val_images}

    all_images_list = [train_images_dict, val_images_dict]

    for images_dict in all_images_list:

        for name, images_list in images_dict.items():

            temp_list = []
            for label in images_list:

                temp_data_t2w = {}
                # Create the image path by replacing the label path
                if dataset_name == 'spine-generic':
                    temp_data_t2w["image"] = label.replace(f'_{sc_fname_suffix}', '').replace('labels',
                                                                                              'data_preprocessed')
                else:
                    temp_data_t2w["image"] = label.replace(f'_{sc_fname_suffix}', '').replace('/derivatives/labels', '')

                # Spinal cord segmentation file
                temp_data_t2w["label_sc"] = label

                if os.path.exists(temp_data_t2w["label_sc"]) and os.path.exists(temp_data_t2w["image"]):
                    temp_list.append(temp_data_t2w)
                else:
                    logger.info(f"Either the image or the label does not exist: {label}")

            params[name] = temp_list
            logger.info(f"Number of images in {name} set: {len(temp_list)}")

    final_json = json.dumps(params, indent=4, sort_keys=False)
    os.makedirs(path_out, exist_ok=True)

    with open(os.path.join(path_out, f"{datalist_fname}.json"), "w") as json_file:
        json_file.write(final_json)
    logger.info(f"JSON file saved to {path_out}/{datalist_fname}.json")


if __name__ == "__main__":
    main()
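As a quick sanity check, a generated datalist can be read back with MONAI's `load_decathlon_datalist` (the same call used by `loader.py` below); a minimal sketch with a hypothetical datalist filename:

```python
from monai.data import load_decathlon_datalist

# Hypothetical datalist filename produced by the script above
datalist_path = "spine-generic_seed42.json"

train_list = load_decathlon_datalist(data_list_file_path=datalist_path, data_list_key="training")
val_list = load_decathlon_datalist(data_list_file_path=datalist_path, data_list_key="validation")

print(f"{len(train_list)} training / {len(val_list)} validation samples")
# Each entry is a dict with 'image' and 'label_sc' keys holding absolute paths
print(train_list[0])
```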

50 changes: 50 additions & 0 deletions pretraining_and_finetuning/loader.py
@@ -0,0 +1,50 @@
from monai.data import (DataLoader, DistributedSampler, CacheDataset, load_decathlon_datalist)

from transforms import train_transforms, val_transforms


def load_data(datalists_paths, train_batch_size, val_batch_size, num_workers=8, use_distributed=False,
              crop_size=(64, 192, 320), patch_size=(64, 64, 64), device="cuda", task="pretraining"):
    """
    Return train and val dataloaders from datalist JSON file(s)
    :param datalists_paths: path(s) to the datalist JSON file(s)
    :param train_batch_size: batch size for the training dataloader
    :param val_batch_size: batch size for the validation dataloader
    :param num_workers: number of workers for the dataloaders
    :param use_distributed: whether to use distributed training
    :param crop_size: crop size; e.g., (64, 192, 320)
    :param patch_size: patch size; e.g., (64, 64, 64)
    :param device: device to load data and apply transforms on
    :param task: task for train/val transforms; choices: pretraining or finetuning
    """
    train_datalist = []
    val_datalist = []
    for datalist_path in datalists_paths:
        train_datalist += load_decathlon_datalist(data_list_file_path=datalist_path, data_list_key="training")
        val_datalist += load_decathlon_datalist(data_list_file_path=datalist_path, data_list_key="validation")

    train_tfs = train_transforms(crop_size, patch_size, device=device, task=task)
    val_tfs = val_transforms(crop_size, task=task)

    # training dataset
    train_ds = CacheDataset(data=train_datalist, transform=train_tfs, cache_rate=0.5, num_workers=4,
                            copy_cache=False)
    # validation dataset
    val_ds = CacheDataset(data=val_datalist, transform=val_tfs, cache_rate=0.25, num_workers=4,
                          copy_cache=False)

    if use_distributed:
        train_sampler = DistributedSampler(dataset=train_ds, even_divisible=True, shuffle=True)
        val_sampler = DistributedSampler(dataset=val_ds, even_divisible=True, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # training dataloader
    # NOTE: 'shuffle' and 'sampler' are mutually exclusive in PyTorch's DataLoader, so shuffling is
    # delegated to the DistributedSampler when one is used
    train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=(train_sampler is None),
                              num_workers=num_workers, pin_memory=True, sampler=train_sampler,
                              persistent_workers=True)
    # validation dataloader
    val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=num_workers,
                            pin_memory=True, sampler=val_sampler, persistent_workers=True)

    return train_loader, val_loader
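A minimal usage sketch for `load_data`; the datalist filename and batch sizes are placeholders, not values from the repo:

```python
from loader import load_data

# Build dataloaders from a single datalist (filename is hypothetical)
train_loader, val_loader = load_data(
    datalists_paths=["spine-generic_seed42.json"],
    train_batch_size=2,
    val_batch_size=1,
    num_workers=4,
    use_distributed=False,  # single-GPU run; no DistributedSampler
    crop_size=(64, 192, 320),
    patch_size=(64, 64, 64),
    task="pretraining",
)

for batch in train_loader:
    # Shapes depend on the transforms; with 64x64x64 patch sampling this is
    # typically (batch, 1, 64, 64, 64)
    print(batch["image"].shape)
    break
```

With `use_distributed=True`, the function must be called inside an initialized `torch.distributed` process group, since `DistributedSampler` queries the default group for rank and world size.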