Multi-head Attention-based Deep Multiple Instance Learning
The attention component used in ABMIL (left), where only one attention module is utilized, and in MAD-MIL (right), where multiple attention modules are incorporated.
This is the PyTorch implementation of MAD-MIL, built on top of CLAM and WSI-finetuning.
Data Preparation
For the preprocessing of the TUPAC16 and TCGA datasets, we follow CLAM's pipeline, extracting features from non-overlapping 256×256 patches at 20× magnification. The extracted features are available at the following links:
- TUPAC16
https://zenodo.org/records/10563985/files/TUPAC.zip?download=1
- BRCA
https://zenodo.org/records/10563985/files/BRCA.zip?download=1
- LUNG
https://zenodo.org/records/10563985/files/LUNG.zip?download=1
- KIDNEY
https://zenodo.org/records/10563985/files/KIDNEY.zip?download=1
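The four Zenodo links above can be fetched programmatically; the small helper below is a hypothetical convenience (the function name and output layout are not part of the repo), and since the archives are large, a resumable downloader such as wget may be preferable in practice.

```python
import os
import urllib.request

# Zenodo record links for the pre-extracted features listed above.
FEATURE_URLS = {
    "TUPAC16": "https://zenodo.org/records/10563985/files/TUPAC.zip?download=1",
    "BRCA": "https://zenodo.org/records/10563985/files/BRCA.zip?download=1",
    "LUNG": "https://zenodo.org/records/10563985/files/LUNG.zip?download=1",
    "KIDNEY": "https://zenodo.org/records/10563985/files/KIDNEY.zip?download=1",
}

def fetch_features(name, out_dir="features"):
    # Download one dataset's zipped feature archive (hypothetical helper);
    # skips the download if the archive is already present.
    os.makedirs(out_dir, exist_ok=True)
    dest = os.path.join(out_dir, f"{name}.zip")
    if not os.path.exists(dest):
        urllib.request.urlretrieve(FEATURE_URLS[name], dest)
    return dest
```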
Training
Training can be run for different models and datasets by setting the appropriate arguments, such as data_root_dir, task, model_type, lr, and reg. Below is an example of training MAD-MIL on TUPAC16.
python train.py --data_root_dir feat-directory --lr 1e-4 --reg 1e-6 --seed 2021 --k 5 --k_end 5 --split_dir task_tupac16 --model_type abmil_multihead --task task_1_tumor_vs_normal --csv_path ./dataset_csv/tupac16.csv --exp_code MAD_Five_reg_1e-6 --n 5
We sweep the weight-decay value over {1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8} and select the value that minimizes the validation loss.
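The sweep described above can be scripted by varying only --reg in the example command; the snippet below generates those invocations (the exp_code naming per run is an assumption, not the repo's convention).

```python
# Weight-decay grid from the sweep described above.
REG_GRID = ["1e-2", "1e-3", "1e-4", "1e-5", "1e-6", "1e-7", "1e-8"]

def sweep_commands(data_root="feat-directory"):
    # Build one train.py invocation per weight-decay value, holding
    # all other arguments fixed as in the TUPAC16 example.
    cmds = []
    for reg in REG_GRID:
        cmds.append(
            "python train.py"
            f" --data_root_dir {data_root} --lr 1e-4 --reg {reg}"
            " --seed 2021 --k 5 --k_end 5 --split_dir task_tupac16"
            " --model_type abmil_multihead --task task_1_tumor_vs_normal"
            " --csv_path ./dataset_csv/tupac16.csv"
            f" --exp_code MAD_Five_reg_{reg} --n 5"  # exp_code naming is a guess
        )
    return cmds
```

Each command can then be launched sequentially or submitted to a job scheduler, with the best run chosen afterwards by validation loss.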
Examples
We present attention heatmaps in the following figure to assess the interpretability of the methods.
Attention heatmaps generated by different models for a tumor slide selected from the TUPAC dataset. Top row, left to right: heatmaps produced by ABMIL and CLAM-MB. Bottom row, left to right: heatmaps generated by MAD-MIL/3-head-1, MAD-MIL/3-head-2, and MAD-MIL/3-head-3.
Reference
Please consider citing the following paper if you find our work useful for your project.
@misc{,
title={Multi-head Attention-based Deep Multiple Instance Learning},
author={},
year={2024},
eprint={},
archivePrefix={},
primaryClass={cs.CV}
}