A simplified from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (refer to model.py).
- A scaled-down version of the original ViT architecture from *An Image is Worth 16x16 Words*.
- Has only 184k parameters (the original ViT-Base has 86 million).
- Works with small datasets by using a smaller patch size of 4 (see the patch-embedding sketch after this list).
- Supported datasets: MNIST, FashionMNIST, SVHN, and CIFAR10.
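The patchify step is the heart of the model: a stride-4 convolution splits each image into non-overlapping 4x4 patches and projects every patch to a 64-dimensional token. Below is a minimal sketch of what this could look like; the class name `PatchEmbedding` and the Conv2d-based patchify are illustrative assumptions, not necessarily the exact code in model.py.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Splits an image into non-overlapping patches and embeds each one.

    A Conv2d with kernel_size == stride == patch_size is equivalent to
    flattening each 4x4 patch and applying a shared linear projection,
    which is the standard ViT patchify trick.
    """
    def __init__(self, n_channels=1, embed_dim=64, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(n_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, embed_dim, H/4, W/4)
        x = self.proj(x)
        # Flatten the spatial grid into a token sequence: (B, seq_len, embed_dim)
        return x.flatten(2).transpose(1, 2)

# For a 1x28x28 MNIST image this yields a sequence of 7*7 = 49 tokens.
tokens = PatchEmbedding()(torch.randn(1, 1, 28, 28))
print(tokens.shape)  # torch.Size([1, 49, 64])
```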
Run commands (also available in `scripts.sh`):
Dataset | Run command | Test Acc (%) |
---|---|---|
MNIST | `python main.py --dataset mnist --epochs 100` | 99.4 |
Fashion MNIST | `python main.py --dataset fmnist` | 92.5 |
SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32` | 91.5 |
CIFAR-10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32` | 77.0 |
The default download path for datasets is `./data`; it can be changed with the `--data_path` argument.
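For context, here is a minimal sketch of how the flags in the table above (`--dataset`, `--data_path`, `--n_channels`, `--image_size`) might map onto torchvision loaders; the exact wiring in main.py may differ:

```python
import argparse
from torchvision import datasets, transforms

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="mnist",
                    choices=["mnist", "fmnist", "svhn", "cifar10"])
parser.add_argument("--data_path", default="./data")
parser.add_argument("--n_channels", type=int, default=1)
parser.add_argument("--image_size", type=int, default=28)
args = parser.parse_args()

transform = transforms.ToTensor()
# Each supported dataset name maps to its torchvision loader.
loaders = {
    "mnist":   lambda: datasets.MNIST(args.data_path, train=True, download=True, transform=transform),
    "fmnist":  lambda: datasets.FashionMNIST(args.data_path, train=True, download=True, transform=transform),
    "svhn":    lambda: datasets.SVHN(args.data_path, split="train", download=True, transform=transform),
    "cifar10": lambda: datasets.CIFAR10(args.data_path, train=True, download=True, transform=transform),
}
train_set = loaders[args.dataset]()
```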
Transformer Config:
Config | MNIST and FMNIST | SVHN and CIFAR10 |
---|---|---|
Input Size | 1 x 28 x 28 | 3 x 32 x 32 |
Patch Size | 4 | 4 |
Sequence Length | 7 x 7 = 49 | 8 x 8 = 64 |
Embedding Size | 64 | 64 |
Number of Layers | 6 | 6 |
Number of Heads | 4 | 4 |
Forward Multiplier | 2 | 2 |
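Reading the table in code: a minimal sketch of how these numbers could combine, using the SVHN/CIFAR10 column and PyTorch's built-in `nn.TransformerEncoder` as a stand-in for the custom blocks in model.py:

```python
import torch.nn as nn

image_size, patch_size = 32, 4                 # SVHN / CIFAR-10 column
embed_dim, n_layers, n_heads, fwd_mult = 64, 6, 4, 2

# Sequence length is the number of patches: (32 / 4)^2 = 8 x 8 = 64 tokens.
seq_len = (image_size // patch_size) ** 2

layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=n_heads,
    dim_feedforward=fwd_mult * embed_dim,      # forward multiplier of 2 -> 128
    batch_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

# On the same order as the ~184k quoted above; the exact count depends on
# how the real model's embeddings, positional encodings, and head are built.
n_params = sum(p.numel() for p in encoder.parameters())
print(seq_len, n_params)
```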