Commit: Updated model architecture to achieve 99.43% accuracy with 15,578 parameters
Showing 13 changed files with 419 additions and 105 deletions.
**`.github/workflows/`** CI workflow "MNIST Model Test" (new file, 59 lines):

```yaml
name: MNIST Model Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]

    steps:
    - uses: actions/checkout@v2

    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}

    - name: Cache pip packages
      uses: actions/cache@v2
      with:
        path: ~/.cache/pip
        key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
        restore-keys: |
          ${{ runner.os }}-pip-

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install pytest pytest-cov

    - name: Run tests
      run: |
        pytest tests/ --cov=./ --cov-report=xml

    - name: Check model parameters
      run: |
        python -c "
        import torch
        from models.model import FastMNIST
        model = FastMNIST()
        total_params = sum(p.numel() for p in model.parameters())
        assert total_params < 20000, f'Model has {total_params} parameters, should be less than 20000'
        print(f'Model parameter count check passed: {total_params} parameters')
        "

    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v1
      with:
        file: ./coverage.xml
        flags: unittests
        name: codecov-umbrella
        fail_ci_if_error: true
```
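The same parameter gate can be run locally before pushing. A minimal sketch of the equivalent check, assuming the repository layout used above (`models/model.py` exporting `FastMNIST`):

```python
# check_params.py: hypothetical local version of the "Check model parameters" CI step.
from models.model import FastMNIST

model = FastMNIST()
total_params = sum(p.numel() for p in model.parameters())
assert total_params < 20000, f"Model has {total_params} parameters, should be less than 20000"
print(f"Model parameter count check passed: {total_params} parameters")
```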
**`README.md`** (modified, 37 → 65 lines):

````diff
@@ -1,37 +1,65 @@
-# MNIST Classification with PyTorch
+# MNIST 99.4% Accuracy Challenge
 
-This project implements a CNN model for MNIST digit classification with the following specifications:
-- Achieves 99.4% validation accuracy
-- Uses less than 20k parameters
-- Trains in less than 20 epochs
-- Implements Batch Normalization and Dropout
-- Uses Fully Connected layers
+This project achieves 99.43% accuracy on MNIST digit classification while maintaining less than 20,000 parameters.
 
 ## Model Architecture
-- Input Layer
-- Convolutional layers with Batch Normalization
-- Max Pooling layers
-- Dropout for regularization
-- Fully Connected layers
-- Output layer (10 classes)
+- Input: 1 channel image (28x28)
+- Convolutional layers:
+  * Conv1: 1 -> 8 channels
+  * Conv2: 8 -> 16 channels
+  * Conv3: 16 -> 32 channels
+- Each conv block includes:
+  * 3x3 convolution with padding=1
+  * BatchNorm
+  * ReLU activation
+  * MaxPool2d
+- Fully connected layers:
+  * FC1: 32 * 3 * 3 -> 32
+  * FC2: 32 -> 10
+- Dropout (0.3) after conv3 and FC1
+
+Total Parameters: 15,578
+
+## Training Configuration
+- Epochs: 19
+- Batch size: 128
+- Optimizer: Adam
+  * Learning rate: 0.001
+  * Weight decay: 1e-4
+- Scheduler: OneCycleLR
+  * Max learning rate: 0.003
+  * pct_start: 0.2
+  * div_factor: 10
+  * final_div_factor: 100
+- Loss: CrossEntropyLoss
+
+## Data Augmentation
+- Random rotation (±10 degrees)
+- Random translation (±10%)
+- Normalization (mean=0.1307, std=0.3081)
+
+## Results
+- Best Test Accuracy: 99.43%
+- Parameters: 15,578 (under 20k limit)
+- Training Time: 19 epochs
 
 ## Requirements
+- Python 3.8+
 - PyTorch
 - torchvision
-- Python 3.8+
-
-## Project Structure
-- `models/model.py`: Model architecture definition
-- `train.py`: Training script
-- `utils.py`: Utility functions
-- `tests/`: Test cases
-- `.github/workflows/`: CI/CD pipeline
+- numpy
+- matplotlib
+- tqdm
 
 ## Usage
 ```bash
 # Train the model
 python train.py
-
-# Run tests
-python -m pytest tests/test_model.py
 ```
+
+## Project Structure
+```
+MNIST_99.4/
+├── models/
+│   └── model.py   # Model architecture
+├── train.py       # Training script
+└── README.md      # Documentation
+```
````
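The configuration above maps directly onto standard PyTorch APIs. Since `train.py` itself is not included in this diff, the following is only a sketch under the README's stated hyperparameters; all variable names are illustrative:

```python
# Illustrative training setup matching the README (not the author's actual train.py).
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.model import FastMNIST

EPOCHS = 19
BATCH_SIZE = 128

# Augmentation: random rotation (±10°), random translation (±10%), MNIST normalization.
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = DataLoader(
    datasets.MNIST("./data", train=True, download=True, transform=train_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

model = FastMNIST()
# The model ends in log_softmax, so NLLLoss is the matching criterion here;
# the README lists CrossEntropyLoss, which expects raw logits instead.
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.003,
    epochs=EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.2,
    div_factor=10,
    final_div_factor=100,
)

for epoch in range(EPOCHS):
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR is stepped per batch, not per epoch
```

One detail worth noting: because `forward` returns `F.log_softmax(...)`, `NLLLoss` is the numerically correct pairing; applying `CrossEntropyLoss` to log-probabilities still trains but applies log-softmax twice.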
**`mnist_model.egg-info/PKG-INFO`** (new file):

```
Metadata-Version: 2.1
Name: mnist_model
Version: 0.1
Requires-Dist: torch>=1.9.0
Requires-Dist: torchvision>=0.10.0
Requires-Dist: pytest>=6.0
Requires-Dist: numpy>=1.19.0
```
**`mnist_model.egg-info/SOURCES.txt`** (new file):

```
README.md
setup.py
mnist_model.egg-info/PKG-INFO
mnist_model.egg-info/SOURCES.txt
mnist_model.egg-info/dependency_links.txt
mnist_model.egg-info/requires.txt
mnist_model.egg-info/top_level.txt
tests/test_model.py
```
**New empty file** (a single blank line; from the file list above and its position in the diff, likely `mnist_model.egg-info/dependency_links.txt`).
**`mnist_model.egg-info/requires.txt`** (new file):

```
torch>=1.9.0
torchvision>=0.10.0
pytest>=6.0
numpy>=1.19.0
```
**New empty file** (a single blank line; from the file list above, likely `mnist_model.egg-info/top_level.txt`).
**New empty file** (a single blank line).
**`models/model.py`** (modified, 45 → 59 lines):

```diff
@@ -1,45 +1,59 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-class MNISTModel(nn.Module):
+class FastMNIST(nn.Module):
+    """
+    A lightweight CNN for MNIST digit classification.
+    Achieves >99.4% accuracy with less than 20k parameters.
+    Architecture:
+        - 3 Convolutional blocks (8->16->32 channels)
+        - BatchNorm after each conv
+        - MaxPool after each block
+        - 2 FC layers (32 neurons in hidden layer)
+        - Dropout (0.3) for regularization
+    Total Parameters: 15,578
+    """
     def __init__(self):
-        super(MNISTModel, self).__init__()
-        # First conv block - input channels: 1, output channels: 8
+        super(FastMNIST, self).__init__()
+        # Simple and effective channel progression
         self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
         self.bn1 = nn.BatchNorm2d(8)
 
-        # Second conv block - input channels: 8, output channels: 16
         self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
         self.bn2 = nn.BatchNorm2d(16)
 
-        # Third conv block - input channels: 16, output channels: 20
-        self.conv3 = nn.Conv2d(16, 20, kernel_size=3, padding=1)
-        self.bn3 = nn.BatchNorm2d(20)
+        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
+        self.bn3 = nn.BatchNorm2d(32)
 
-        self.dropout1 = nn.Dropout(0.25)
-        self.dropout2 = nn.Dropout(0.25)
+        # Efficient FC layers
+        self.fc1 = nn.Linear(32 * 3 * 3, 32)
+        self.fc2 = nn.Linear(32, 10)
+
+        self.dropout = nn.Dropout(0.3)
 
-        # Reduced FC layers
-        self.fc1 = nn.Linear(20 * 3 * 3, 64)  # Smaller FC layer
-        self.fc2 = nn.Linear(64, 10)
 
     def forward(self, x):
-        # First conv block
-        x = F.relu(self.bn1(self.conv1(x)))  # 28x28 -> 28x28
-        x = F.max_pool2d(x, 2)  # 28x28 -> 14x14
-
-        # Second conv block
-        x = F.relu(self.bn2(self.conv2(x)))  # 14x14 -> 14x14
-        x = F.max_pool2d(x, 2)  # 14x14 -> 7x7
-        x = self.dropout1(x)
-
-        # Third conv block
-        x = F.relu(self.bn3(self.conv3(x)))  # 7x7 -> 7x7
-        x = F.max_pool2d(x, 2)  # 7x7 -> 3x3
-        x = self.dropout2(x)
-
-        # Flatten and FC layers
-        x = x.view(x.size(0), -1)  # Flatten: 20 * 3 * 3 = 180
-        x = F.relu(self.fc1(x))
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout(x)
+
+        x = x.view(-1, 32 * 3 * 3)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout(x)
         x = self.fc2(x)
-        return x
+        return F.log_softmax(x, dim=1)
```
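The documented total of 15,578 parameters can be verified layer by layer; a minimal standalone check against the class above:

```python
# Sanity check for the architecture documented in the docstring.
import torch
from models.model import FastMNIST

model = FastMNIST()

# Expected per-layer counts:
#   conv1: 1*8*3*3 + 8     =     80    bn1: 2*8  =  16
#   conv2: 8*16*3*3 + 16   =  1,168    bn2: 2*16 =  32
#   conv3: 16*32*3*3 + 32  =  4,640    bn3: 2*32 =  64
#   fc1:   288*32 + 32     =  9,248    fc2: 32*10 + 10 = 330
#   total                  = 15,578
for name, module in model.named_children():
    print(f"{name:8s} {sum(p.numel() for p in module.parameters()):6d}")
assert sum(p.numel() for p in model.parameters()) == 15578

# One forward pass: log-probabilities of shape (N, 10) that exponentiate to a distribution.
out = model(torch.randn(4, 1, 28, 28))
assert out.shape == (4, 10)
assert torch.allclose(out.exp().sum(dim=1), torch.ones(4), atol=1e-5)
```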
**`requirements.txt`** (new file):

```
torch>=2.0.0
torchvision>=0.15.0
numpy>=1.21.0
matplotlib>=3.5.0
tqdm>=4.65.0
pytest>=7.0.0
```
**`setup.py`** (new file):

```python
from setuptools import setup, find_packages

setup(
    name="mnist_model",
    version="0.1",
    packages=find_packages(),
    install_requires=[
        'torch>=1.9.0',
        'torchvision>=0.10.0',
        'pytest>=6.0',
        'numpy>=1.19.0',
    ],
)
```
**`tests/test_model.py`** (modified, 10 → 67 lines):

```diff
@@ -1,10 +1,67 @@
-from utils import count_parameters, has_batch_norm, has_dropout, has_fully_connected
-from models.model import MNISTModel
+import pytest
+import torch
+from models.model import FastMNIST
 
-model = MNISTModel()
+@pytest.fixture
+def model():
+    return FastMNIST()
 
-# Unit test checks
-assert count_parameters(model) <= 20000, "Model exceeds 20k parameters!"
-assert has_batch_norm(model), "Batch Normalization not used in the model!"
-assert has_dropout(model), "Dropout not used in the model!"
-assert has_fully_connected(model), "Fully Connected layers not used in the model!"
+def test_model_structure(model):
+    """Test the model structure and parameter count."""
+    total_params = sum(p.numel() for p in model.parameters())
+    assert total_params < 20000, f"Model has {total_params} parameters, should be < 20000"
+
+def test_forward_pass(model):
+    """Test the forward pass with a batch of data."""
+    batch_size = 32
+    x = torch.randn(batch_size, 1, 28, 28)
+    output = model(x)
+
+    # Check output shape
+    assert output.shape == (batch_size, 10), f"Expected shape (32, 10), got {output.shape}"
+
+    # Output is log-probabilities, so the exponentials form a distribution
+    assert torch.allclose(torch.exp(output).sum(dim=1), torch.ones(batch_size), atol=1e-6)
+
+def test_batch_norm_layers(model):
+    """Test that batch normalization layers are present."""
+    assert hasattr(model, 'bn1'), "Model missing bn1 layer"
+    assert hasattr(model, 'bn2'), "Model missing bn2 layer"
+    assert hasattr(model, 'bn3'), "Model missing bn3 layer"
+
+def test_dropout_layer(model):
+    """Test that the dropout layer is present with the correct rate."""
+    assert hasattr(model, 'dropout'), "Model missing dropout layer"
+    assert model.dropout.p == 0.3, f"Expected dropout rate 0.3, got {model.dropout.p}"
+
+def test_conv_layers(model):
+    """Test convolutional layers configuration."""
+    assert model.conv1.in_channels == 1, "Conv1 should have 1 input channel"
+    assert model.conv1.out_channels == 8, "Conv1 should have 8 output channels"
+
+    assert model.conv2.in_channels == 8, "Conv2 should have 8 input channels"
+    assert model.conv2.out_channels == 16, "Conv2 should have 16 output channels"
+
+    assert model.conv3.in_channels == 16, "Conv3 should have 16 input channels"
+    assert model.conv3.out_channels == 32, "Conv3 should have 32 output channels"
+
+def test_fc_layers(model):
+    """Test fully connected layers configuration."""
+    assert model.fc1.in_features == 32 * 3 * 3, "FC1 input features incorrect"
+    assert model.fc1.out_features == 32, "FC1 output features should be 32"
+    assert model.fc2.in_features == 32, "FC2 input features should be 32"
+    assert model.fc2.out_features == 10, "FC2 output features should be 10"
+
+def test_gradient_flow(model):
+    """Test that gradients can flow through the model."""
+    x = torch.randn(1, 1, 28, 28, requires_grad=True)
+    output = model(x)
+    loss = output.sum()
+    loss.backward()
+
+    # Check that all parameters have gradients
+    for name, param in model.named_parameters():
+        assert param.grad is not None, f"No gradient for {name}"
+        assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}"
```
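The suite can also be launched programmatically rather than through the workflow's `pytest tests/` step; a short sketch using pytest's public API:

```python
# run_tests.py: hypothetical programmatic equivalent of `python -m pytest tests/test_model.py -v`.
import pytest

raise SystemExit(pytest.main(["tests/test_model.py", "-v"]))
```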