From 863fc7c828a049dc2855c8fa0c6ef2d018100605 Mon Sep 17 00:00:00 2001 From: Stephen Date: Sat, 13 Nov 2021 21:45:55 +0800 Subject: [PATCH] first commit --- .flake8 | 7 ++ .gitignore | 145 +++++++++++++++++++++++++++++++++++++++ Dockerfile | 13 ++++ Makefile | 30 ++++++++ heroku.yml | 5 ++ pyproject.toml | 28 ++++++++ requirements.txt | 8 +++ setup.py | 3 + src/main.py | 145 +++++++++++++++++++++++++++++++++++++++ src/net.py | 29 ++++++++ src/predict.py | 17 +++++ streamlit/st_app.py | 45 ++++++++++++ tests/test_everything.py | 42 ++++++++++++ 13 files changed, 517 insertions(+) create mode 100644 .flake8 create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 heroku.yml create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 src/main.py create mode 100644 src/net.py create mode 100644 src/predict.py create mode 100644 streamlit/st_app.py create mode 100644 tests/test_everything.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..7ef62d0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +exclude = .venv +ignore = E501, W503, E226 + +; E501: Line too long +; W503: Line break occurred before binary operator +; E226: Missing white space around arithmetic operator \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5b7e33a --- /dev/null +++ b/.gitignore @@ -0,0 +1,145 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# vscode +/.vscode + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Others +*.pt +data/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c524aca --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.7-slim + +COPY setup.py setup.py +COPY requirements.txt requirements.txt +COPY Makefile Makefile +COPY src src +COPY streamlit streamlit + +RUN make install + +EXPOSE 8080 + +CMD ["streamlit", "streamlit/st_app.py"] \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4fb713a --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +env: + virtualenv ~/.venv &&\ + source ~/.venv/bin/activate + + +install: + pip install --upgrade pip &&\ + pip install -r requirements.txt &&\ + pip install -e . + + +style: + black . + flake8 . 
+
+
+test:
+	pytest
+
+
+deploy:
+	git push heroku master
+
+
+.PHONY: streamlit
+streamlit:
+	streamlit run streamlit/st_app.py
+
+
+all: env install style test
\ No newline at end of file
diff --git a/heroku.yml b/heroku.yml
new file mode 100644
index 0000000..0820636
--- /dev/null
+++ b/heroku.yml
@@ -0,0 +1,5 @@
+build:
+  docker:
+    web: Dockerfile
+run:
+  web: bundle exec puma -C config/puma.rb
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..667d839
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,28 @@
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = "test_*.py"
+filterwarnings = [
+    "error",
+    "ignore::DeprecationWarning:",
+    # note the use of single quote below to denote "raw" strings in TOML
+    'ignore:Using or importing the ABCs:DeprecationWarning',
+]
+
+
+[tool.black]
+line-length = 100
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.eggs         # exclude a few common directories in the root of the project
+  | \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+  )/
+'''
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..21735fc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+torch==1.6.0
+torchvision==0.7.0
+black==20.8b1
+flake8==3.8.4
+pytest==6.2.5
+streamlit==0.75.0
+streamlit_drawable_canvas==0.8.0
+pillow==7.2.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..fb927b9
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,3 @@
+from setuptools import setup
+
+setup(name="mnist", packages=["src"])
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..f44699f
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,171 @@
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+from src.net import Net
+
+
+def get_loader(batch_size, train_mode, download=True, **kwargs):
+    """Build a DataLoader over the MNIST train or test split.
+
+    Images are stored under ``./data`` and pass through a random affine jitter
+    followed by ``ToTensor``.
+
+    Args:
+        batch_size: number of samples per batch.
+        train_mode: True selects the training split, False the test split.
+        download: forwarded to ``datasets.MNIST``.
+        **kwargs: extra keyword arguments forwarded to ``DataLoader``.
+
+    Returns:
+        A ``torch.utils.data.DataLoader`` over the chosen split.
+    """
+    # NOTE(review): the random jitter is applied to the test split too, so evaluation
+    # runs on perturbed images -- confirm this is intended.
+    transform = transforms.Compose(
+        [
+            transforms.RandomAffine(degrees=(-5, 5), translate=(0.1, 0.3), scale=(0.75, 1.2)),
+            transforms.ToTensor(),
+        ]
+    )
+    dataset = datasets.MNIST("./data", train=train_mode, download=download, transform=transform)
+    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, **kwargs)
+
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    """Run one training epoch with NLL loss.
+
+    Progress is printed every ``args.log_interval`` batches; with ``args.dry_run``
+    set, the loop stops right after the first logged batch.
+    """
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(data),
+                    len(train_loader.dataset),
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )
+            if args.dry_run:
+                break
+
+
+def tst(model, device, test_loader):
+    """Evaluate ``model`` on ``test_loader`` and print the mean NLL loss and accuracy."""
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print(
+        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
+            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
+        )
+    )
+
+
+def main():
+    """Parse command-line options, then train and evaluate the MNIST classifier."""
+    # Training settings
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=10,
+        metavar="N",
+        help="number of epochs to train (default: 10)",
+    )
+    parser.add_argument(
+        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--dry-run", action="store_true", default=False, help="quickly check a single pass"
+    )
+    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--save-model", action="store_true", default=False, help="For saving the current Model"
+    )
+    args = parser.parse_args()
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    torch.manual_seed(args.seed)
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    # Bind the loader kwargs on every path: defining them only under ``use_cuda``
+    # raised UnboundLocalError below whenever training ran on CPU.
+    cuda_kwargs = {}
+    if use_cuda:
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+
+    train_loader = get_loader(batch_size=args.batch_size, train_mode=True, **cuda_kwargs)
+    test_loader = get_loader(batch_size=args.test_batch_size, train_mode=False, **cuda_kwargs)
+
+    model = Net().to(device)
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    max_epochs = 2 if args.dry_run else args.epochs + 1
+    for epoch in range(1, max_epochs):
+        train(args, model, device, train_loader, optimizer, epoch)
+        tst(model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "mnist_cnn.pt")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/net.py b/src/net.py
new file
mode 100644
index 0000000..5718930
--- /dev/null
+++ b/src/net.py
@@ -0,0 +1,34 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class Net(nn.Module):
+    """Small CNN for 1x28x28 inputs that outputs log-probabilities over 10 classes."""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        # 9216 = 64 channels * 12 * 12: a 28x28 input shrinks to 24x24 after the two
+        # 3x3 convolutions and to 12x12 after the 2x2 max-pool in forward().
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
diff --git a/src/predict.py b/src/predict.py
new file mode 100644
index 0000000..3218bad
--- /dev/null
+++ b/src/predict.py
@@ -0,0 +1,35 @@
+import numpy as np
+import torch
+from torchvision import transforms
+from src.net import Net
+
+
+def predict(img_numpy, ckpt_path):
+    """Return the log-probabilities ``Net`` assigns to a single grayscale image.
+
+    Args:
+        img_numpy: 2-D image array. ``uint8`` input is rescaled from [0, 255] to
+            [0, 1] by ``ToTensor``; any other dtype is cast to ``float32`` and used
+            as-is.
+        ckpt_path: path to a ``state_dict`` checkpoint for ``Net``.
+
+    Returns:
+        Tensor of log-probabilities with a leading batch dimension of 1.
+    """
+    transform = transforms.Compose([transforms.ToTensor()])
+
+    # ToTensor only rescales uint8 arrays. Casting them to float32 first handed the
+    # network raw 0-255 values, while training feeds it inputs scaled to [0, 1].
+    if img_numpy.dtype != np.uint8:
+        img_numpy = img_numpy.astype(np.float32)
+    img_torch = transform(img_numpy)
+    img_torch = img_torch.unsqueeze(dim=0)
+
+    net = Net()
+    # map_location lets a checkpoint saved from a CUDA model load on a CPU-only host.
+    net.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
+    net.eval()
+    # Inference only: skip autograd bookkeeping.
+    with torch.no_grad():
+        log_preds = net(img_torch)
+    return log_preds
diff --git a/streamlit/st_app.py b/streamlit/st_app.py
new file mode 100644
index 0000000..084afaf
--- /dev/null
+++ b/streamlit/st_app.py
@@ -0,0 +1,47 @@
+import cv2
+import numpy as np
+import streamlit as st
+from streamlit_drawable_canvas import st_canvas
+from src.predict import predict
+
+
+MODEL_INPUT_SIZE = 28
+CANVAS_SIZE = MODEL_INPUT_SIZE * 8
+CHECKPOINT_PATH = "mnist_cnn.pt"
+
+
+def main():
+    """Render a drawing canvas and show the model's digit prediction for it."""
+    st.write("Draw something here")
+    canvas_res = st_canvas(
+        stroke_width=20,
+        stroke_color="white",
+        background_color="black",
+        width=CANVAS_SIZE,
+        height=CANVAS_SIZE,
+        drawing_mode="freedraw",
+        key="canvas",
+        display_toolbar=True,
+    )
+
+    if canvas_res.image_data is not None:
+        # Scale down image to the model input size
+        img = cv2.resize(
+            canvas_res.image_data.astype("uint8"), (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE)
+        )
+        # Collapse to a single gray channel, then enlarge a copy for display only
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        img_rescaled = cv2.resize(img, (CANVAS_SIZE, CANVAS_SIZE), interpolation=cv2.INTER_NEAREST)
+
+        st.write("Downscaled model input:")
+        st.image(img_rescaled)
+
+        # predict() returns log-probabilities, hence the exp() for the confidence
+        pred = predict(np.array(img), CHECKPOINT_PATH).detach().numpy()
+        st.write(
+            f"The predicted digit is {pred.argmax()}. The model is {np.exp(pred.max()) * 100:.1f}% sure."
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_everything.py b/tests/test_everything.py
new file mode 100644
index 0000000..3473de7
--- /dev/null
+++ b/tests/test_everything.py
@@ -0,0 +1,49 @@
+import pytest
+import torch
+import numpy as np
+from src.net import Net
+from src.main import train, tst, get_loader
+from src.predict import predict
+
+
+@pytest.fixture
+def net():
+    return Net()
+
+
+def test_net(net):
+    BATCH_SIZE = 2
+    x = torch.zeros(BATCH_SIZE, 1, 28, 28)
+    out = net(x)
+    assert out.shape == (BATCH_SIZE, 10)
+
+
+class Args:
+    """Minimal stand-in for the parsed CLI options that ``train`` reads."""
+
+    dry_run = True
+    log_interval = 100
+
+
+def test_train_loop(net):
+    device = "cpu"
+    train_loader = get_loader(8, True)
+    optimizer = torch.optim.Adadelta(net.parameters(), lr=0.001)
+    epoch = 1
+    train(Args, net, device, train_loader, optimizer, epoch)
+
+
+def test_testing_loop(net):
+    device = "cpu"
+    test_loader = get_loader(8, False)
+    tst(net, device, test_loader)
+
+
+def test_predict(net, tmp_path):
+    # Save a throwaway checkpoint: ``*.pt`` files are git-ignored, so a fixed file
+    # name would make this test fail on a fresh checkout.
+    ckpt_path = str(tmp_path / "mnist_cnn.pt")
+    torch.save(net.state_dict(), ckpt_path)
+    x = np.zeros((28, 28))
+    out = predict(x, ckpt_path)
+    assert out.shape == (1, 10)