Deep learning models for CIFAR10 implemented in PyTorch. So far, CNN, RNN and AlexNet models are implemented. The code is useful for comparing the performance of different model architectures, and it runs on both CPU-only machines and GPUs.
A Python environment with PyTorch, torchvision and scikit-learn is required.
Download the Python version of the CIFAR10 dataset from the official website: https://www.cs.toronto.edu/~kriz/cifar.html. It is an archive of pickled batch files. load_data.py contains functions that load the data from these pickle files into a PyTorch Dataset.
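The sketch below shows one way such a loader can work. It is not the code in load_data.py; the names `unpickle` and `CIFAR10PickleDataset` are hypothetical. It relies on the batch format described on the dataset website: each batch file is a pickled dict whose `b"data"` entry is a uint8 array of shape (N, 3072), with the 1024 red values of a 32x32 image first, then green, then blue, and whose `b"labels"` entry is a list of N integers in 0-9.

```python
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


def unpickle(path):
    """Load one CIFAR10 batch file (a pickled dict with bytes keys)."""
    with open(path, "rb") as f:
        return pickle.load(f, encoding="bytes")


class CIFAR10PickleDataset(Dataset):
    """Hypothetical Dataset over the python-version CIFAR10 batch files."""

    def __init__(self, batch_files):
        images, labels = [], []
        for path in batch_files:
            batch = unpickle(path)
            # b"data": (N, 3072) uint8; each row holds 1024 red, then 1024 green,
            # then 1024 blue values, each plane stored row by row.
            images.append(np.asarray(batch[b"data"], dtype=np.uint8))
            labels.extend(batch[b"labels"])
        data = np.concatenate(images).reshape(-1, 3, 32, 32)
        # Channels-first float32 tensors with pixel values scaled to [0, 1].
        self.images = torch.from_numpy(data).float().div(255.0)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]
```

Assuming the archive was extracted to `cifar-10-batches-py/`, the training set would be `CIFAR10PickleDataset([f"cifar-10-batches-py/data_batch_{i}" for i in range(1, 6)])` and the test set `CIFAR10PickleDataset(["cifar-10-batches-py/test_batch"])`.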
To train a CNN, run train_cnn_model.py. Several architectures are defined in custom_models.py. To add a new architecture, create a class that inherits from nn.Module and implements the __init__ and forward methods. Models are evaluated with a confusion matrix and the accuracy, i.e. the percentage of correctly classified images.
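A minimal sketch of both steps follows. `SmallCNN` is a hypothetical architecture, not one of the models in custom_models.py: it defines its layers in `__init__` and connects them in `forward`. `evaluate` is likewise a hypothetical helper that builds the confusion matrix with scikit-learn's `confusion_matrix` and derives the accuracy from it, since correct predictions are exactly the entries on its diagonal. It assumes the model has already been moved to `device`.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix


class SmallCNN(nn.Module):
    """Hypothetical example architecture for 3x32x32 CIFAR10 images."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers with parameters are created once, in __init__ ...
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        # ... and wired together in forward. x: (batch, 3, 32, 32)
        x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 32, 16, 16)
        x = self.pool(F.relu(self.conv2(x)))  # -> (batch, 64, 8, 8)
        return self.fc(torch.flatten(x, 1))   # -> (batch, num_classes) logits


def evaluate(model, loader, num_classes=10, device="cpu"):
    """Return (confusion matrix, accuracy in percent) over a DataLoader.

    Assumes the model is already on `device`.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())
    # Rows are true classes, columns are predicted classes.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    # Correct predictions lie on the diagonal.
    accuracy = 100.0 * np.trace(cm) / cm.sum()
    return cm, accuracy
```

Passing `labels=list(range(num_classes))` keeps the matrix 10x10 even when some class never appears in a small evaluation set.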
Same as for the CNN, but run train_rnn_model.py instead.
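An RNN needs the image presented as a sequence. The sketch below assumes one common choice: each 32x32 RGB image is read one pixel row per time step, giving 32 steps of 3 x 32 = 96 features, which are fed to an LSTM. This is an illustration only; `RowLSTM` is a hypothetical name, and the models used by train_rnn_model.py may build their sequences or recurrent layers differently.

```python
import torch
import torch.nn as nn


class RowLSTM(nn.Module):
    """Hypothetical RNN classifier that reads an image one pixel row per step."""

    def __init__(self, hidden_size=128, num_layers=1, num_classes=10):
        super().__init__()
        # Each time step sees one image row: 3 channels x 32 pixels = 96 features.
        self.lstm = nn.LSTM(input_size=3 * 32, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, 3, 32, 32) -> (batch, 32 rows, 3, 32) -> (batch, 32, 96)
        seq = x.permute(0, 2, 1, 3).reshape(x.size(0), 32, 3 * 32)
        out, _ = self.lstm(seq)  # out: (batch, 32, hidden_size)
        # Classify from the output after the last row has been read.
        return self.fc(out[:, -1, :])
```

Because it inherits from nn.Module and returns (batch, 10) logits, a model like this can be trained and evaluated the same way as the CNNs.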
Several prebuilt model architectures are available in torchvision: https://github.com/pytorch/vision/tree/master/torchvision/models.
The code in train_alexnet_model.py implements the AlexNet architecture from the link above.