
Pytorch models implemented on CIFAR10


cecP/cifar10-pytorch



Deep learning models for CIFAR10 implemented in PyTorch. CNN, RNN and AlexNet models have been implemented so far. The repository is useful for comparing the performance of different model architectures, and the code runs either on CPU only or on GPU.

Prerequisites

A Python environment with PyTorch, torchvision and scikit-learn is required.

Getting the Data

Download the Python version of the CIFAR10 dataset from the official website: https://www.cs.toronto.edu/~kriz/cifar.html. It is an archive of pickle files. load_data.py contains functions that load the data from the pickle files into a PyTorch Dataset.
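The functions in load_data.py are not reproduced here. The snippet below is an independent sketch of the same idea, based on the format described on the dataset page. Each batch file (`data_batch_1` to `data_batch_5`, `test_batch`) unpickles to a dict whose `b"data"` entry is an N x 3072 uint8 array (1024 red, then 1024 green, then 1024 blue values per image) and whose `b"labels"` entry is a list of ints. The names `load_cifar_batch` and `CIFAR10Pickle` are hypothetical.

```python
import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


def load_cifar_batch(path):
    """Unpickle one CIFAR10 batch file.

    Returns images as uint8 with shape (N, 3, 32, 32) and labels as int64 with shape (N,).
    """
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")  # keys are bytes, e.g. b"data"
    # Each row is 3072 bytes in channel-major order (R plane, G plane, B plane).
    images = np.asarray(batch[b"data"], dtype=np.uint8).reshape(-1, 3, 32, 32)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels


class CIFAR10Pickle(Dataset):
    """Hypothetical Dataset over the extracted cifar-10-batches-py directory."""

    def __init__(self, root, train=True):
        names = [f"data_batch_{i}" for i in range(1, 6)] if train else ["test_batch"]
        parts = [load_cifar_batch(os.path.join(root, n)) for n in names]
        self.images = np.concatenate([p[0] for p in parts])
        self.labels = np.concatenate([p[1] for p in parts])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert to float in [0, 1]; any further normalization is left to the caller.
        image = torch.from_numpy(self.images[idx]).float() / 255.0
        return image, int(self.labels[idx])
```

Wrapping it in `torch.utils.data.DataLoader(CIFAR10Pickle("cifar-10-batches-py"), batch_size=64, shuffle=True)` yields batches of shape (64, 3, 32, 32). The directory path here is an assumption about where the archive was extracted.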



Models

CNN models

Use train_cnn_model.py. Some architectures are defined in custom_models.py. To implement a new architecture, create a class that inherits from nn.Module and implements the __init__ and forward methods. Accuracy is evaluated with a confusion matrix and the percentage of correctly classified images.
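As an illustration of that recipe, here is a minimal sketch. It is not one of the architectures in custom_models.py and not the evaluation code from train_cnn_model.py. It shows a small CNN that subclasses `nn.Module`, plus an evaluation helper that computes a confusion matrix and a percentage accuracy with scikit-learn. The names `SmallCNN` and `evaluate` and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, confusion_matrix


class SmallCNN(nn.Module):
    """Hypothetical example architecture for 3x32x32 CIFAR10 images."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   # keeps 32x32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # keeps 16x16
        self.pool = nn.MaxPool2d(2)                               # halves H and W
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # -> (N, 32, 16, 16)
        x = self.pool(F.relu(self.conv2(x)))  # -> (N, 64, 8, 8)
        x = torch.flatten(x, 1)               # -> (N, 4096)
        x = F.relu(self.fc1(x))
        return self.fc2(x)                    # raw class scores (logits)


@torch.no_grad()
def evaluate(model, loader, num_classes=10, device="cpu"):
    """Return (confusion_matrix, accuracy in percent) over a DataLoader."""
    model.eval()
    y_true, y_pred = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
        y_true.extend(labels.tolist())
    # Rows are true classes, columns are predicted classes.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    accuracy = 100.0 * accuracy_score(y_true, y_pred)
    return cm, accuracy
```

The diagonal of the confusion matrix counts correct predictions per class. The percentage accuracy is the diagonal sum divided by the total number of test images, times 100.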

RNN models

Same as for the CNN models, but use train_rnn_model.py.
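The README does not say how the RNN models turn an image into a sequence. One common approach, assumed here, feeds the image row by row: each of the 32 rows becomes one time step with 32 pixels x 3 channels = 96 features. The sketch below follows that assumption with an LSTM. The class name `RowLSTM` and the hidden sizes are hypothetical.

```python
import torch
import torch.nn as nn


class RowLSTM(nn.Module):
    """Hypothetical RNN classifier that reads a CIFAR10 image one row at a time."""

    def __init__(self, hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        # 96 input features per step = 32 pixels in a row * 3 colour channels.
        self.lstm = nn.LSTM(input_size=96, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (N, 3, 32, 32) -> (N, 32 rows, 32 cols, 3 channels) -> (N, 32, 96)
        n = x.size(0)
        seq = x.permute(0, 2, 3, 1).reshape(n, 32, 96)
        out, _ = self.lstm(seq)      # out: (N, 32, hidden_size)
        return self.fc(out[:, -1])   # classify from the last time step
```

Because the input and output shapes match the CNN sketch, such a model could be passed to the same training and evaluation code unchanged.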

Prebuilt pytorch models

Some prebuilt model architectures can be found here: https://github.com/pytorch/vision/tree/master/torchvision/models.

The code in train_alexnet_model.py implements the AlexNet architecture from the link above.
