Vision Transformer sets the benchmark for image representation. However, unique challenges arise in certain imaging fields such as microscopy and satellite imaging:
- Unlike RGB images, images in these domains often contain multiple channels, each potentially carrying semantically distinct and independent information.
- Not all input channels may densely be available at training or test time, necessitating a model that performs robustly under these conditions.
In response to these challenges, we introduce ChannelViT and Hierarchical Channel Sampling.
- ChannelViT constructs patch tokens independently from each input channel. It employs a set of learnable channel embeddings to encode channel-specific information, which are then added to the patch tokens in a manner akin to positional embeddings. This modification enables ChannelViT to perform cross-channel and cross-position reasoning, a critical feature for multi-channel imaging.
- Hierarchical Channel Sampling (HCS) employs a two-step sampling procedure to simulate test time channel unavailability during training. Unlike channel dropout, where each channel is dropped independently and biases a certain number of selected channels, the two-stage sampling procedure ensures HCS covers channel combinations with varying numbers of channels uniformly. This results in a consistent and significant improvement in robustness.
Should you have any questions or require further assistance, please do not hesitate to create an issue. We are here to provide support. 🤗
This project is developed based on PyTorch 2.0 and PyTorch-Lightning 2.0.1. We use conda to manage the Python environment. You can setup the enviroment by running
git clone [email protected]:insitro/ChannelViT.git
cd ChannelViT
conda env create -f environment.yml
conda activate channelvit
You can then install channelvit through pip.
pip install git+https://github.com/insitro/ChannelViT.git
This section outlines the steps to reproduce our training and evaluation pipelines using JUMP-CP. The preprocessed JUMP-CP data used in this example was previously released in our work titled, "Contextual Vision Transformers for Robust Representation Learning" (insitro/ContextViT).
Before initiating the training of any models, it is beneficial to visualize the correlations among the input channels. The script provided below will load the JUMP-CP data and compute the channel correlations derived from the original cell painting images, without any normalization.
python channelvit/main/main_correlation.py \
trainer.devices=8 \
trainer.max_epochs=100 \
data@train_data=jumpcp \
data@val_data_dict=[jumpcp_test] \
transformations@train_transformations=cell \
transformations@val_transformations=cell
val_transformations.normalization.mean=[0,0,0,0,0,0,0,0] \
val_transformations.normalization.std=[1,1,1,1,1,1,1,1]
Next, we will train a ViT-S/16 model without using HCS. For managing our experiment configuration, we utilize hydra. The script provided below will load its corresponding main configuration file, channelvit/config/main_supervised.yaml
, along with any command line overrides. It trains the ViT-S/16 model to minimize the cross-entropy loss on the JUMP-CP training data over the course of 100 epochs. The process requires 8 GPUs and operates with a batch size of 32 per GPU.
python channelvit/main/main_supervised.py \
trainer.devices=8 \
trainer.max_epochs=100 \
meta_arch/backbone=vit_small \
meta_arch.backbone.args.in_chans=8 \
meta_arch.target='label' \
meta_arch.num_classes=161 \
data@train_data=jumpcp \
data@val_data_dict=[jumpcp_val,jumpcp_test] \
train_data.jumpcp.loader.batch_size=32 \
transformations@train_transformations=cell \
transformations@val_transformations=cell
Given that each cell image in JUMP-CP contains 8 channels, we override the input channels to 8. Throughout the training, we save the snapshots in the ./snapshots/
directory. You can alter this path by overriding the value of trainer.default_root_dir
.
To train the ViT-S/16 using hierarchical channel sampling, simply override the meta_arch/backbone setting to hcs_vit_small. With this setting, the Hierarchical Channel Sampling (HCS) will perform the following actions for each batch:
- Randomly determine the number of channels to be used for the current batch.
- Randomly select the combinations of channels.
- Scale the selected channels by a factor, which is calculated as the ratio of the total number of channels to the number of selected channels. The script below provides an example of how to train the ViT-S/8 model using HCS.
python channelvit/main/main_supervised.py \
trainer.devices=8 \
trainer.max_epochs=100 \
meta_arch/backbone=hcs_vit_small \
meta_arch.backbone.args.in_chans=8 \
meta_arch.backbone.args.patch_size=8 \
meta_arch.target='label' \
meta_arch.num_classes=161 \
data@train_data=jumpcp \
data@val_data_dict=[jumpcp_val,jumpcp_test] \
train_data.jumpcp.loader.batch_size=32 \
transformations@train_transformations=cell \
transformations@val_transformations=cell
The script below will enumerate all possible channel combinations and evaluate the corresponding testing accuracy of the trained model (stored at PATH_TO_CKPT
). In this case, we set transformation_mask=True
because ViT assumes the same number of input channels for the patch embedding layer. The unselected channels will be filled with zeros, and the selected channels will be scaled by the ratio of the total number of channels to the number of selected channels.
python channelvit/main/main_supervised_evalall.py \
trainer.devices=8 \
transformation_mask=True \
data@val_data=jumpcp_test \
val_data.jumpcp_test.loader.batch_size=32 \
val_data.jumpcp_test.args.channels=[0,1,2,3,4,5,6,7] \
transformations=cell \
checkpoint=${PATH_TO_CKPT}
Training ChannelViT follows a similar process to training ViT. All you need to do is override the meta_arch/backbone setting.
python channelvit/main/main_supervised.py \
trainer.devices=8 \
trainer.max_epochs=100 \
meta_arch/backbone=channelvit_small \
meta_arch.backbone.args.in_chans=8 \
meta_arch.target='label' \
meta_arch.num_classes=161 \
data@train_data=jumpcp \
data@val_data_dict=[jumpcp_val,jumpcp_test] \
train_data.jumpcp.loader.batch_size=32 \
transformations@train_transformations=cell \
transformations@val_transformations=cell
Given that the patch token of ChannelViT originates from a single channel, applying HCS with ChannelViT essentially results in a shorter input patch sequence for the model. Unlike ViT, where we need to perform input rescaling to maintain smooth input distributions when different channels are used, with ChannelViT we can simply exclude the patches corresponding to the unselected channels from the input sequence.
The following script provides an example of how to train the ChannelViT-S/8 model using HCS.
python channelvit/main/main_supervised.py \
trainer.devices=8 \
trainer.max_epochs=100 \
meta_arch/backbone=hcs_channelvit_small \
meta_arch.backbone.args.in_chans=8 \
meta_arch.target='label' \
meta_arch.num_classes=161 \
data@train_data=jumpcp \
data@val_data_dict=[jumpcp_val,jumpcp_test] \
train_data.jumpcp.loader.batch_size=32 \
transformations@train_transformations=cell \
transformations@val_transformations=cell
The performance of the trained models can be evaluated over all channel combinations using the main_supervised_evalall.py
script, similar to the evaluation process for ViT. However, there is one key difference. As we consider input channels as the sequence length of the transformer encoders, we set transformation_mask=True
. This means we don't need to reset pixels corresponding to these channels to zero.
python channelvit/main/main_supervised_evalall.py \
trainer.devices=8 \
transformation_mask=True \
data@val_data=jumpcp_test \
val_data.jumpcp_test.loader.batch_size=32 \
val_data.jumpcp_test.args.channels=[0,1,2,3,4,5,6,7] \
transformations=cell \
checkpoint=${PATH_TO_CKPT}
To use the ImageNet 2012 dataset, you must first download it from the official release. We have provided a dataset class in channelvit/data/imagenet.py
for loading the ImageNet data. This class requires two dataframes (one for the training split and one for the validation split) with keys path (the path to the image) and label (the image's label). You can specify the path to these dataframes in the Hydra configurations (channelvit/config/data/imagenet.yaml
and channelvit/config/data/imagenet_valid.yaml
).
We suggest using the official WILDS package to download the Camelyon17-WILDS dataset.
pip install wilds
# This will download the labeled data under wilds_base_path/
python wilds/download_datasets.py --root_dir wilds_base_path
After successfully downloading the data, you can set the base path of the WILDS dataset in the Hydra configuration (for example, channelvit/config/data/camelyon_train.yaml
).
To download the So2Sat dataset, follow the instructions in the official release. We use the first version (hard split) and the third version (random split). We have provided a dataset class in channelvit/data/so2sat.py
for loading the So2Sat data. This class requires a dataframe with keys path (the path to the image) and label (the image's label). You can specify the path to this dataframe in the Hydra configuration (for example, channelvit/config/data/so2sat_hard.yaml
).
If our work contributes to your research, we would greatly appreciate a citation.
@article{bao2023channel,
title={Channel Vision Transformers: An Image Is Worth C x 16 x 16 Words},
author={Bao, Yujia and Sivanandan, Srinivasan and Karaletsos, Theofanis},
journal={arXiv preprint arXiv:2309.16108},
year={2023}
}