Simplified PyTorch implementation of a Vision Transformer (ViT) for the MNIST dataset.

Vision Transformer-MNIST-CIFAR10

Simplified from-scratch PyTorch implementation of a Vision Transformer (ViT), with detailed steps (see model.py).

  • Scaled-down version of the original ViT architecture from An Image is Worth 16x16 Words.
  • Has only 184k parameters (the original ViT-Base has 86 million).
  • Works with small datasets by using a smaller patch size of 4.
  • Supported datasets: MNIST, FashionMNIST, SVHN, and CIFAR10.
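The scaled-down architecture described above can be sketched in a few lines of PyTorch. This is a hedged, minimal illustration using `nn.TransformerEncoder` with the hyperparameters from the config table (embedding size 64, 6 layers, 4 heads, forward multiplier 2, patch size 4); the actual model.py in this repository may implement the blocks differently.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    """Minimal ViT sketch: patchify -> embed -> encoder -> class token -> head.
    Hyperparameters mirror the repo's config table; this is an illustration,
    not the repository's exact model.py."""
    def __init__(self, image_size=28, patch_size=4, n_channels=1,
                 embed_dim=64, n_layers=6, n_heads=4, forward_mult=2,
                 n_classes=10):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2  # 7*7 = 49 for MNIST
        # Patch embedding as a strided convolution: one token per 4x4 patch
        self.patch_embed = nn.Conv2d(n_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads,
            dim_feedforward=forward_mult * embed_dim,  # forward multiplier 2
            batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)            # (B, embed_dim, 7, 7)
        x = x.flatten(2).transpose(1, 2)   # (B, 49, embed_dim) token sequence
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)
        return self.head(x[:, 0])          # classify from the class token
```

A batch of MNIST images of shape `(B, 1, 28, 28)` yields logits of shape `(B, 10)`; passing `image_size=32, n_channels=3` adapts the same sketch to SVHN/CIFAR10.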



Run commands (also available in scripts.sh):

| Dataset | Run command | Test Acc (%) |
|---|---|---|
| MNIST | `python main.py --dataset mnist --epochs 100` | 99.4 |
| Fashion-MNIST | `python main.py --dataset fmnist` | 92.5 |
| SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32` | 91.5 |
| CIFAR-10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32` | 77.0 |

Datasets are downloaded to `./data` by default; this path can be changed with the `--data_path` argument.



Transformer config:

| Config | MNIST and FMNIST | SVHN and CIFAR10 |
|---|---|---|
| Input size | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch size | 4 | 4 |
| Sequence length | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding size | 64 | 64 |
| Number of layers | 6 | 6 |
| Number of heads | 4 | 4 |
| Forward multiplier | 2 | 2 |
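The sequence lengths in the table follow directly from tiling the image with non-overlapping patches: with patch size 4, a 28×28 image yields a 7×7 grid of patches (49 tokens) and a 32×32 image an 8×8 grid (64 tokens). A small sketch of the arithmetic:

```python
def seq_len(image_size: int, patch_size: int) -> int:
    """Number of patch tokens when non-overlapping patches tile the image."""
    assert image_size % patch_size == 0, "image must divide evenly into patches"
    return (image_size // patch_size) ** 2

print(seq_len(28, 4))  # 49 -> MNIST / FashionMNIST
print(seq_len(32, 4))  # 64 -> SVHN / CIFAR10
```

If a class token is prepended (as in the original ViT), the transformer actually processes one token more than these counts (50 and 65 respectively).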
