
Supervised pre-training on SC segmentations using swinunetr #12

Open
valosekj opened this issue Apr 5, 2024 · 4 comments

valosekj commented Apr 5, 2024

Description

This issue summarizes some early experiments with supervised pre-training on SC segmentations.

WIP branch: nk/jv_vit_unetr_ssl
Pre-training script: pretraining_and_finetuning/main_supervised_pretraining.py

Experiments

Unlike SSL experiments done in #7 and #9, the pre-training done in this issue is supervised, done on SC segmentations.

T2w images for the supervised pre-training come from 5 datasets (canproco, dcm-zurich, sci-colorado, sci-paris, spine-generic multi-subject). Number of training samples: 654. Number of validation samples: 163. Details about the images are provided in dataset-conversion/README.md.

I'm currently training two different swinunetr models. Both with crop_pad_size: [64, 160, 320] and patch_size: [64, 64, 64].
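For reference, the effect of `crop_pad_size` followed by `patch_size` sampling can be sketched in plain numpy. This is an illustrative re-implementation of the idea, not the repo's actual MONAI transform chain; the axis ordering and the symmetric pad/crop policy are assumptions:

```python
# Sketch: pad-or-crop a volume to crop_pad_size, then sample one random
# training patch of patch_size. Helper names and shapes are illustrative.
import numpy as np

CROP_PAD_SIZE = (64, 160, 320)
PATCH_SIZE = (64, 64, 64)

def pad_or_crop(vol, target):
    """Center-pad (with zeros) or center-crop each axis to the target size."""
    out = vol
    for ax, t in enumerate(target):
        s = out.shape[ax]
        if s < t:                       # pad symmetrically
            before = (t - s) // 2
            pad = [(0, 0)] * out.ndim
            pad[ax] = (before, t - s - before)
            out = np.pad(out, pad)
        elif s > t:                     # crop symmetrically
            start = (s - t) // 2
            out = np.take(out, range(start, start + t), axis=ax)
    return out

def random_patch(vol, patch, rng):
    """Extract one random patch; MONAI's random crops follow the same logic."""
    starts = [rng.integers(0, s - p + 1) for s, p in zip(vol.shape, patch)]
    slices = tuple(slice(st, st + p) for st, p in zip(starts, patch))
    return vol[slices]

rng = np.random.default_rng(0)
img = rng.standard_normal((70, 120, 350))   # a hypothetical T2w volume
img = pad_or_crop(img, CROP_PAD_SIZE)
patch = random_patch(img, PATCH_SIZE, rng)
print(img.shape, patch.shape)               # (64, 160, 320) (64, 64, 64)
```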

Experiment 1 - Model with CropForegroundd - multiple datasets

This model uses transforms.CropForegroundd(keys=all_keys, source_key="label_sc") to crop everything outside the SC mask. See pretraining_and_finetuning/transforms.py.
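The idea behind `CropForegroundd(source_key="label_sc")` is to crop every key (image and labels) to the bounding box of the nonzero SC-mask voxels. A minimal numpy sketch of that bounding-box crop (the `margin` parameter is an assumption for illustration, not necessarily the value used in the repo):

```python
# Sketch of the foreground-cropping idea: everything outside the bounding
# box of the SC mask is discarded before patch sampling.
import numpy as np

def crop_foreground(image, label_sc, margin=0):
    """Crop `image` to the bounding box of the nonzero voxels of `label_sc`."""
    coords = np.nonzero(label_sc)
    slices = tuple(
        slice(max(int(c.min()) - margin, 0), min(int(c.max()) + 1 + margin, s))
        for c, s in zip(coords, label_sc.shape)
    )
    return image[slices], slices

img = np.zeros((8, 8, 8))
sc = np.zeros((8, 8, 8))
sc[2:5, 3:6, 1:4] = 1            # a toy SC mask
cropped, sl = crop_foreground(img, sc)
print(cropped.shape)             # (3, 3, 3)
```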

GIF of validation samples (validation is done every 5 epochs)

Validation hard dice dropped to zero after ~100 epochs:

loss_plots (validation is done every 5 epochs)
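"Hard dice" here refers to the Dice score computed on binarized predictions; when a model collapses to predicting all-background, it drops to exactly zero. A minimal sketch (the 0.5 threshold is an assumption):

```python
# Sketch of the hard (binarized) Dice metric used for validation.
import numpy as np

def hard_dice(prob, gt, thr=0.5, eps=1e-8):
    """Binarize predicted probabilities at `thr`, then compute Dice vs. GT.
    An all-background (collapsed) prediction yields a Dice of 0."""
    pred = (prob > thr).astype(np.float32)
    inter = (pred * gt).sum()
    return float(2 * inter / (pred.sum() + gt.sum() + eps))

gt = np.zeros((4, 4))
gt[1:3, 1:3] = 1
print(hard_dice(gt * 0.9, gt))              # ~1.0 : near-perfect prediction
print(hard_dice(np.zeros_like(gt), gt))     # 0.0 : collapsed model
```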

Experiment 2 - Model without CropForegroundd - multiple datasets

This model does NOT use transforms.CropForegroundd(keys=all_keys, source_key="label_sc").

GIF of validation samples (validation is done every 5 epochs)

loss_plots (validation is done every 5 epochs)

Model training crashed due to OSError: [Errno 112] Host is down ... (possibly because I'm still loading data from duke/temp?). So I resumed the training from the best checkpoint (epoch ~65). Training resumed, but then the validation hard dice dropped to zero:
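One thing worth checking when resuming: whether the checkpoint restores everything the training loop needs (epoch counter, best metric, and optimizer/scheduler state), since a partially restored state is a common cause of post-resume divergence. A stdlib-only sketch of the bookkeeping (file layout and field names are assumptions, not the repo's actual checkpoint format):

```python
# Sketch of resume-from-checkpoint bookkeeping; a real run would also
# restore model and optimizer state_dicts (e.g. via torch.save/torch.load).
import json, os, tempfile

def save_checkpoint(path, epoch, best_dice, extra):
    with open(path, "w") as f:
        json.dump({"epoch": epoch, "best_dice": best_dice, "extra": extra}, f)

def resume(path):
    if not os.path.exists(path):
        return 0, 0.0, {}                 # fresh run
    with open(path) as f:
        ckpt = json.load(f)
    # resume from the NEXT epoch so the checkpointed one is not repeated
    return ckpt["epoch"] + 1, ckpt["best_dice"], ckpt["extra"]

path = os.path.join(tempfile.mkdtemp(), "best.json")
save_checkpoint(path, 65, 0.91, {"lr": 1e-4})
start_epoch, best, extra = resume(path)
print(start_epoch, best)   # 66 0.91
```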

loss_plots after resume (validation is done every 5 epochs)

valosekj added a commit that referenced this issue Apr 6, 2024
Pre-training with 'transforms.CropForegroundd' was crashing to zero. Details: #12

valosekj commented Apr 6, 2024

Since both swinunetr with and without CropForegroundd crashed to zero when trained on T2w images from multiple datasets (details in the comment above), I tried to train swinunetr on a single dataset (spine-generic multi-subject). And training finished successfully!

Experiment 3 - Model with CropForegroundd - spine-generic only

This model used transforms.CropForegroundd(keys=all_keys, source_key="label_sc").
Number of training samples: 213. Number of validation samples: 54.

loss_plots (validation is done every 5 epochs)

GIF of validation samples (validation is done every 5 epochs)

Experiment 4 - Model without CropForegroundd - spine-generic only

This model did NOT use transforms.CropForegroundd(keys=all_keys, source_key="label_sc").
Number of training samples: 213. Number of validation samples: 54.

loss_plots (validation is done every 5 epochs)

GIF of validation samples (validation is done every 5 epochs)

Notice that, besides the SC, the model also predicts other structures, for example:

val_00199_048
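A common post-processing step for suppressing such spurious predictions is to keep only the largest connected component of the binarized output. A self-contained sketch (pure-Python BFS flood fill; a real pipeline would more likely use `scipy.ndimage.label`):

```python
# Sketch: keep only the largest connected component of a binary prediction,
# discarding small spurious blobs outside the spinal cord.
import numpy as np
from collections import deque

def largest_component(mask):
    """Return a mask containing only the largest face-connected component."""
    mask = mask.astype(bool)
    visited = np.zeros_like(mask)
    best, best_size = np.zeros_like(mask), 0
    for start in zip(*np.nonzero(mask)):
        if visited[start]:
            continue
        comp, q = [], deque([start])
        visited[start] = True
        while q:                              # BFS flood fill
            p = q.popleft()
            comp.append(p)
            for ax in range(mask.ndim):
                for d in (-1, 1):
                    n = list(p); n[ax] += d; n = tuple(n)
                    if all(0 <= n[i] < mask.shape[i] for i in range(mask.ndim)) \
                            and mask[n] and not visited[n]:
                        visited[n] = True
                        q.append(n)
        if len(comp) > best_size:
            best_size = len(comp)
            best = np.zeros_like(mask)
            for p in comp:
                best[p] = True
    return best

pred = np.zeros((10, 10), dtype=int)
pred[1:7, 2] = 1          # the spinal cord (large component)
pred[8, 8] = 1            # a spurious blob
clean = largest_component(pred)
print(int(clean.sum()))   # 6
```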

Conclusion

I originally thought that the training collapsing to zero was due to using transforms.CropForegroundd(keys=all_keys, source_key="label_sc"). But when I trained only on spine-generic, training both with and without transforms.CropForegroundd finished successfully; it seems that the collapse may have originated from using images from multiple datasets.

@naga-karthik

> it seems that the collapse may have originated from using images from multiple datasets.

I kind of disagree with this, because I have trained on spine-generic and basel-mp2rage for contrast-agnostic and it worked fine. The crash you report on multiple datasets might be specific to that experiment: once training stopped and resumed from the checkpoint, there might have been an issue with loading the checkpoint and resuming training.

If we compare (1) spine-generic with CropForegroundd and (2) spine-generic + lesion datasets with CropForegroundd, while ensuring that the training does not crash at any point, we might reach a different conclusion!


valosekj commented Apr 8, 2024

Thanks @naga-karthik!

> (2) spine-generic + lesion datasets with CropForegroundd

I tried the following experiment:

swinunetr with CropForegroundd pre-trained on three datasets (spine-generic multi-subject, dcm-zurich, and sci-paris) for SC seg. And training finished without any crashes!

loss_plots (validation is done every 5 epochs)

GIF of validation samples (validation is done every 5 epochs)

So, now that we have several pre-trained models, I'm moving to fine-tuning on lesions!


btw, it is hard to say what caused the training crash in #12 (comment). I'll try to figure this out later.

@naga-karthik

Do you also have some pre-trained nnunet or monai-unet models?
