
HSF_TRAIN

Empower MRI Insights: Pristine Precision, Effortless Integration


Developed with Python, YAML, and ONNX.



Overview

The hsf_train project is a configurable framework for training, fine-tuning, and exporting deep learning models for MRI segmentation, with a focus on hippocampal subfields. It combines custom segmentation models, data augmentation, and loss functions adapted from nnUNet, integrates SparseML for pruning and quantization, exports deployment-ready models to ONNX, and tracks experiments with Weights & Biases. Flexible configuration of model architectures, datasets, and training environments makes it an end-to-end solution for training efficient segmentation networks for medical image analysis.


Features

  • ⚙️ Architecture: The project follows a modular design built on the PyTorch Lightning framework, with configurable models, data loaders, and training configurations. The codebase supports fine-tuning of segmentation models, model export in ONNX format, and integration with SparseML for model optimization.
  • 🔩 Code Quality: The codebase follows Python coding conventions and is organized into modules and classes, with meaningful variable and function names and type hints that improve readability and maintainability. External libraries and dependencies are used appropriately.
  • 📄 Documentation: The code carries inline comments explaining the main functions and modules, and the configuration files describe their purpose and available options. More comprehensive documentation and usage examples would be a welcome improvement.
  • 🔌 Integrations: The codebase integrates with external libraries and tools such as wandb (Weights & Biases), torchio, and SparseML, enabling logging, visualization, data augmentation, and model optimization.
  • 🧩 Modularity: Concerns are separated into distinct files and modules, allowing easy extension and customization. The configuration files provide a flexible way to adjust settings, making the codebase adaptable to different use cases and datasets.
  • 🧪 Testing: The project does not ship a test suite. Given its modular structure, unit tests could be added with pytest or unittest to ensure correctness and prevent regressions.
  • ⚡️ Performance: Performance depends on the chosen models and datasets. The use of Lightning and SparseML indicates a focus on efficient training and optimization; a detailed evaluation of speed and resource usage would require benchmarking specific use cases.
  • 🛡️ Security: The project does not document specific security measures. Following Python security best practices, keeping dependencies up to date, and reviewing code changes help protect data and mitigate vulnerabilities.
  • 📦 Dependencies: The project relies on external libraries such as rich, pymia, wandb, torch, xxhash, onnxruntime, sparseml, lightning, torchio, python-dotenv, and yaml. The requirements.txt file lists these dependencies for easy setup and reproducibility.

Repository Structure

└── hsf_train/
    ├── LICENSE
    ├── conf
    │   ├── config.yaml
    │   ├── datasets
    │   │   ├── all.yaml
    │   │   └── custom_dataset.yaml
    │   ├── finetuning
    │   │   └── default.yaml
    │   ├── lightning
    │   │   └── default.yaml
    │   ├── logger
    │   │   └── wandb.yaml
    │   └── models
    │       └── default.yaml
    ├── finetune.py
    ├── hsftrain
    │   ├── callbacks.py
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── dataloader.py
    │   │   └── loader.py
    │   ├── exporter.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── blocks.py
    │   │   ├── helpers.py
    │   │   ├── layers.py
    │   │   ├── losses.py
    │   │   ├── models.py
    │   │   ├── optimizer.py
    │   │   ├── scheduler.py
    │   │   └── types.py
    │   └── utils.py
    ├── main.py
    ├── requirements.txt
    ├── scripts
    │   └── ckpt_to_onnx.py
    └── sparseml
        ├── finetuning.yaml
        └── scratch.yaml

Modules

.
  • finetune.py: Fine-tunes a pre-trained segmentation model for MRI images with custom configurations, trains it with dataset-specific augmentation, and logs performance metrics to Weights & Biases. It supports exporting the fine-tuned model to ONNX, with optional conversion to the DeepSparse format when SparseML is used. Key features include data preparation, model training, checkpointing, and export, in line with the repository's focus on flexible, efficient deep learning workflows for medical imaging.
  • main.py: The central entry point of the training pipeline for MRI segmentation. It integrates data preprocessing, augmentation, and postprocessing, trains a custom SegmentationModel with a FocalTversky loss, and tracks experiments with Weights & Biases. Optional SparseML integration adds pruning and quantization, followed by ONNX export for deployment (see the sketch after this list).
  • requirements.txt: Lists the dependencies required to run the repository, including PyTorch, Lightning, and ONNX-related packages for training, serialization, and experiment tracking.
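
The following is a minimal, self-contained sketch of how these pieces fit together: a Lightning module trained by a Trainer with a Wandb logger. The tiny network, tensor shapes, and hyperparameters are illustrative stand-ins for the real SegmentationModel and torchio-based data pipeline under hsftrain/, not the repository's actual code.

import lightning as L
import torch
from lightning.pytorch.loggers import WandbLogger


class TinySegModel(L.LightningModule):
    """Illustrative stand-in for hsftrain's SegmentationModel."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.net = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


if __name__ == "__main__":
    # Random volumes stand in for the torchio-based MRI pipeline.
    x = torch.randn(8, 1, 16, 16, 16)
    y = torch.randint(0, 2, (8, 16, 16, 16))
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x, y), batch_size=2)

    # The WandbLogger expects WANDB_API_KEY to be set (see Getting Started).
    trainer = L.Trainer(max_epochs=1, accelerator="auto",
                        logger=WandbLogger(project="hsf_train"))
    trainer.fit(TinySegModel(), loader)
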
conf
  • config.yaml: The top-level configuration of the repository. It sets defaults for datasets, model architecture, training parameters, and logging, keeping configuration streamlined and modular (see the composition sketch below).
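
The conf/ layout and the key=value overrides shown under Getting Started suggest a Hydra-style composition. The snippet below is a rough illustration of how such a config tree is consumed; whether the entry points are wired exactly this way is an assumption, and only the config-group names above are taken from the repository.

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config", version_base=None)
def run(cfg: DictConfig) -> None:
    # Groups such as cfg.datasets, cfg.models, cfg.lightning and cfg.logger
    # are filled in from the matching YAML files listed in the tree above.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    run()
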
conf.datasets
  • all.yaml: A comprehensive configuration for dataset management, tailored to hippocampal MRI data from several domains (e.g. hiplay, memodev). It defines data paths, MRI patterns, label mappings, and parameters such as batch size and train/test split ratios, so that datasets are standardized and processed consistently during training and fine-tuning.
  • custom_dataset.yaml: Specifies loading and processing details for a custom dataset, including paths, batch size, worker count, memory pinning, partitioning ratios, and MRI image/label pairings, streamlining dataset integration into the training process (a data-loading sketch follows this list).
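
As a data-loading sketch, the snippet below turns a dataset entry of this kind (a root path, MRI/label filename patterns, batch size, workers, memory pinning) into a torchio loader. The paths, glob patterns, and transforms are illustrative assumptions; the actual logic lives in hsftrain/data/.

from pathlib import Path

import torch
import torchio as tio

root = Path("/data/hippocampus")                 # dataset path (illustrative)
mris = sorted(root.glob("**/*_tse.nii.gz"))      # MRI pattern (illustrative)
labels = sorted(root.glob("**/*_seg.nii.gz"))    # label pattern (illustrative)

subjects = [
    tio.Subject(mri=tio.ScalarImage(m), label=tio.LabelMap(l))
    for m, l in zip(mris, labels)
]
transform = tio.Compose([tio.ZNormalization(), tio.RandomFlip(axes=("LR",))])
dataset = tio.SubjectsDataset(subjects, transform=transform)

loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, num_workers=4, pin_memory=True)
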
conf.finetuning
  • default.yaml: Defines the fine-tuning parameters, focused on the decoder: model depth, how often layers are unfrozen, and the number of output channels, all of which directly affect model adaptation and performance (a sketch of scheduled unfreezing follows this list).
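
The sketch below shows the kind of scheduled unfreezing such a configuration implies: a Lightning callback that opens one more decoder level every few epochs. The attribute layout (pl_module.net.decoder as a sequence of blocks) and the unfreeze_every default are assumptions, not the repository's implementation.

from lightning.pytorch.callbacks import Callback


class GradualUnfreeze(Callback):
    """Unfreeze one additional decoder level every `unfreeze_every` epochs."""

    def __init__(self, unfreeze_every: int = 2):
        self.unfreeze_every = unfreeze_every

    def on_train_epoch_start(self, trainer, pl_module):
        blocks = list(pl_module.net.decoder)      # hypothetical module layout
        n_open = min(len(blocks),
                     1 + trainer.current_epoch // self.unfreeze_every)
        # Open the deepest levels first, keep the rest frozen.
        for i, block in enumerate(reversed(blocks)):
            for p in block.parameters():
                p.requires_grad = i < n_open
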
conf.lightning
  • default.yaml: Configures the training environment: GPU acceleration, automatic strategy selection, mixed precision, and training parameters such as the number of epochs and gradient accumulation (see the Trainer sketch below).
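
A hedged mapping of these settings onto a Lightning Trainer looks roughly as follows; the concrete values and the exact YAML key names are assumptions.

import lightning as L

trainer = L.Trainer(
    accelerator="gpu",          # GPU acceleration
    strategy="auto",            # automatic strategy selection
    precision="16-mixed",       # mixed precision
    max_epochs=100,             # number of epochs (value illustrative)
    accumulate_grad_batches=4,  # gradient accumulation (value illustrative)
)
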
conf.logger
  • wandb.yaml: Configures Weights & Biases (wandb) logging, providing structured experiment tracking and visualization for the training pipeline (see the sketch below).
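
A minimal sketch of wiring this up, with the API key read from .env as described under Getting Started; the project name is a placeholder.

from dotenv import load_dotenv
from lightning.pytorch.loggers import WandbLogger

load_dotenv()  # makes WANDB_API_KEY from .env visible to the wandb client

logger = WandbLogger(project="hsf_train", log_model=False)
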
conf.models
  • default.yaml: Defines the default hyperparameters of the 3D Residual U-Net, including architecture specifics and training options.

hsftrain
  • callbacks.py: Provides the SparseML integration for optimizing training and exporting to ONNX. It applies a given SparseML recipe during training (single-optimizer setups only) and supports model finalization and ONNX export, with configurable batch-normalization fusing and QAT conversion.
  • exporter.py: Transforms PyTorch models into ONNX format, supporting several ONNX versions and adjusting models for export (e.g. disabling batch-norm fusing and tuning quantization settings for ONNX compatibility). It relies on settings from config.yaml and produces interchange- and deployment-ready models.
  • utils.py: Utility functions for file-integrity verification and model fetching. Downloaded models are checked against their expected hashes, providing automated version control and integrity guarantees for training and inference (a hashing sketch follows this list).
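
A minimal sketch of hash-based file verification in the spirit of utils.py; xxhash is among the project's dependencies, but the hash variant (xxh3_64) and helper names here are assumptions.

from pathlib import Path

import xxhash


def file_digest(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the file and return its xxHash hex digest."""
    h = xxhash.xxh3_64()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path, expected: str) -> bool:
    return file_digest(path) == expected
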
hsftrain.models
  • blocks.py: Defines the custom neural-network blocks (in PyTorch) from which the model architectures are built.
  • helpers.py: Utility functions for the models, such as counting learnable parameters and computing the number of features per level, used for model configuration.
  • layers.py: Implements SwitchNorm3d, an adaptive normalization layer for 3D data that learns to weight Instance, Layer, and Batch Normalization during training.
  • losses.py: Segmentation loss functions adapted from nnUNet, including TverskyLoss and FocalTversky_loss, together with tensor utilities. A "forgiving" loss variant adjusts the model's sensitivity to specific types of prediction errors (a loss sketch follows this list).
  • models.py: Defines the model architecture and evaluation metrics, using the Lightning framework for the training lifecycle and the pymia library for evaluation criteria.
  • optimizer.py: Implements the AdamP optimizer, a variant of Adam with gradient centralization, adaptive gradient clipping, and optional adanorm adaptation for more controlled weight updates.
  • scheduler.py: Implements a learning-rate scheduler with linear warmup followed by cosine annealing or linear decay, adjusting the learning rate over epochs (a scheduler sketch follows this list).
  • types.py: Defines type aliases and structures for parameters, optimizer state, and loss computation, supporting type-safe operations across the training pipeline.
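
Two short, self-contained sketches of the ideas behind losses.py and scheduler.py follow. Neither is the repository's exact implementation: the function interfaces, the default alpha/beta/gamma values, and the epoch-based stepping are assumptions.

A Focal Tversky loss penalizes false positives and false negatives asymmetrically through the Tversky index TI = TP / (TP + alpha*FP + beta*FN), and focuses training on hard examples by raising (1 - TI) to a power gamma:

import torch


def focal_tversky_loss(logits, target, alpha=0.3, beta=0.7, gamma=0.75, eps=1e-6):
    """logits: (N, C, D, H, W) raw scores; target: (N, D, H, W) integer labels."""
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    onehot = torch.nn.functional.one_hot(target, num_classes)
    onehot = onehot.permute(0, 4, 1, 2, 3).float()

    dims = (0, 2, 3, 4)                      # sum over batch and spatial dims
    tp = (probs * onehot).sum(dims)
    fp = (probs * (1 - onehot)).sum(dims)
    fn = ((1 - probs) * onehot).sum(dims)

    tversky = (tp + eps) / (tp + alpha * fp + beta * fn + eps)
    return ((1 - tversky) ** gamma).mean()

A linear-warmup plus cosine-annealing schedule can be expressed as a multiplicative factor applied to the base learning rate:

import math

import torch


def warmup_cosine(optimizer, warmup_epochs: int, total_epochs: int):
    """Linear warmup for `warmup_epochs`, then cosine decay towards zero."""
    def factor(epoch: int) -> float:
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=factor)
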
scripts
  • ckpt_to_onnx.py: Converts trained segmentation models from PyTorch checkpoints to ONNX format for interoperability and deployment, reusing the model parameters and custom loss functions defined in the repository (see the export sketch below).
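
A rough sketch of such a conversion is shown below. A tiny stand-in network is used so the example runs on its own; the real script first restores the trained model from its checkpoint, and the input shape, tensor names, and opset version here are assumptions.

import torch

# Stand-in network; the real script loads the trained segmentation model
# from a Lightning checkpoint instead.
net = torch.nn.Sequential(
    torch.nn.Conv3d(1, 8, kernel_size=3, padding=1),
    torch.nn.GELU(),
    torch.nn.Conv3d(8, 2, kernel_size=1),
)
net.eval()

dummy = torch.randn(1, 1, 64, 64, 64)  # (batch, channel, D, H, W)
torch.onnx.export(
    net, dummy, "model.onnx",
    input_names=["mri"], output_names=["segmentation"],
    dynamic_axes={"mri": {0: "batch"}, "segmentation": {0: "batch"}},
    opset_version=17,
)
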
sparseml
  • finetuning.yaml: The SparseML recipe for the fine-tuning phase. It sets the number of training epochs and schedules pruning and quantization periods, controlling how model sparsity and quantization progress in order to reduce model size and speed up inference.
  • scratch.yaml: The SparseML recipe for training from scratch. It schedules epoch-based pruning to reduce model size before applying quantization for further compression, which matters for deployment in resource-constrained environments (see the recipe-usage sketch below).
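
For reference, the generic SparseML pattern for applying such a recipe to a plain PyTorch loop is sketched below; the repository instead wires the recipe through its Lightning callback (see hsftrain/callbacks.py), and the model, optimizer, and steps_per_epoch here are placeholders.

import torch
from sparseml.pytorch.optim import ScheduledModifierManager

model = torch.nn.Sequential(torch.nn.Conv3d(1, 8, kernel_size=3, padding=1))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Load the recipe and let it wrap the optimizer so pruning/quantization
# modifiers run on schedule during training.
manager = ScheduledModifierManager.from_yaml("sparseml/scratch.yaml")
optimizer = manager.modify(model, optimizer, steps_per_epoch=100)

# ... usual training loop using `optimizer` ...

manager.finalize(model)
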

Getting Started

Requirements

Ensure you have the following dependencies installed on your system:

  • Python: version 3.8.0 or higher

Installation

  1. Clone the hsf_train repository:
git clone https://github.com/clementpoiret/hsf_train
  2. Change to the project directory:
cd hsf_train
  3. Install the dependencies:
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -r requirements.txt
  4. Set up the environment variables for Weights & Biases:
touch .env
echo "WANDB_API_KEY=your_api_key" > .env

Running hsf_train

Use the following command to train a new model:

python main.py

Example command to fine-tune a model:

python finetune.py datasets=custom_dataset finetuning.out_channels=8 models.lr=1e-3 models.use_forgiving_loss=False

Changelog

The implementation in the hsf_train repository differs slightly from the original paper, as it introduces several improvements and optimizations.

Here are the main changes:

  • All downsampling operations are switched from max-pooling to strided convolutions,
  • ReLU activations are replaced by GELU (see the sketch below),
  • The model is trained with a different optimizer, AdamP, a variant of Adam,
  • The learning rate is scheduled with cosine annealing and a linear warmup,
  • AdamP is further improved with Adaptive Gradient Clipping and Demon (decaying momentum).
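
As a minimal sketch of the first two changes (strided convolutions instead of max-pooling, GELU instead of ReLU), a downsampling block could look like the following; channel counts and kernel sizes are illustrative, not the repository's exact block.

import torch
from torch import nn


class DownBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.block = nn.Sequential(
            # Strided convolution halves the spatial resolution (no max-pool).
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU(),
        )

    def forward(self, x):
        return self.block(x)


if __name__ == "__main__":
    x = torch.randn(1, 8, 32, 32, 32)
    print(DownBlock(8, 16)(x).shape)  # torch.Size([1, 16, 16, 16, 16])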

Contributing

Contributions are welcome! Here are several ways you can contribute:

  • Submit Pull Requests: Review open PRs, and submit your own PRs.
  • Report Issues: Submit bugs found or log feature requests for the hsf_train project.
Contributing Guidelines
  1. Fork the Repository: Start by forking the project repository to your GitHub account.
  2. Clone Locally: Clone the forked repository to your local machine using a git client.
    git clone https://github.com/clementpoiret/hsf_train
  3. Create a New Branch: Always work on a new branch, giving it a descriptive name.
    git checkout -b new-feature-x
  4. Make Your Changes: Develop and test your changes locally.
  5. Commit Your Changes: Commit with a clear message describing your updates.
    git commit -m 'Implemented new feature x.'
  6. Push to GitHub: Push the changes to your forked repository.
    git push origin new-feature-x
  7. Submit a Pull Request: Create a PR against the original project repository. Clearly describe the changes and their motivations.

Once your PR is reviewed and approved, it will be merged into the main branch.


License

This project is released under the MIT License.


Citing HSF

If you use HSF in your research, please cite the following paper:

@ARTICLE{10.3389/fninf.2023.1130845,
    AUTHOR={Poiret, Clement and Bouyeure, Antoine and Patil, Sandesh and Grigis, Antoine and Duchesnay, Edouard and Faillot, Matthieu and Bottlaender, Michel and Lemaitre, Frederic and Noulhiane, Marion},
    TITLE={A fast and robust hippocampal subfields segmentation: HSF revealing lifespan volumetric dynamics},
    JOURNAL={Frontiers in Neuroinformatics},
    VOLUME={17},
    YEAR={2023},
    URL={https://www.frontiersin.org/articles/10.3389/fninf.2023.1130845},
    DOI={10.3389/fninf.2023.1130845},
    ISSN={1662-5196},
    ABSTRACT={The hippocampal subfields, pivotal to episodic memory, are distinct both in terms of cyto- and myeloarchitectony. Studying the structure of hippocampal subfields in vivo is crucial to understand volumetric trajectories across the lifespan, from the emergence of episodic memory during early childhood to memory impairments found in older adults. However, segmenting hippocampal subfields on conventional MRI sequences is challenging because of their small size. Furthermore, there is to date no unified segmentation protocol for the hippocampal subfields, which limits comparisons between studies. Therefore, we introduced a novel segmentation tool called HSF short for hippocampal segmentation factory, which leverages an end-to-end deep learning pipeline. First, we validated HSF against currently used tools (ASHS, HIPS, and HippUnfold). Then, we used HSF on 3,750 subjects from the HCP development, young adults, and aging datasets to study the effect of age and sex on hippocampal subfields volumes. Firstly, we showed HSF to be closer to manual segmentation than other currently used tools (p < 0.001), regarding the Dice Coefficient, Hausdorff Distance, and Volumetric Similarity. Then, we showed differential maturation and aging across subfields, with the dentate gyrus being the most affected by age. We also found faster growth and decay in men than in women for most hippocampal subfields. Thus, while we introduced a new, fast and robust end-to-end segmentation tool, our neuroanatomical results concerning the lifespan trajectories of the hippocampal subfields reconcile previous conflicting results.}
}
