Commit 863fc7c, committed by Stephen on Nov 13, 2021 (initial commit, 0 parents). Showing 13 changed files with 517 additions and 0 deletions.
Flake8 configuration (the file name is not shown in the rendered diff; presumably .flake8):
@@ -0,0 +1,7 @@
[flake8]
exclude = .venv
ignore = E501, W503, E226

; E501: Line too long
; W503: Line break occurred before binary operator
; E226: Missing white space around arithmetic operator
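These suppressions keep flake8 compatible with black, which breaks long expressions before binary operators (W503) and leaves simple powers such as x**2 unspaced (E226), while line length is handled by black itself (set to 100 in the TOML config further down). A small illustrative snippet, not taken from this repository, showing what the ignored codes would otherwise flag:

```python
# Hypothetical example (not part of the commit): code in the style black
# produces, which would trip W503 and E226 if they were not ignored above.
def scaled_area(radius: float, scale: float) -> float:
    return (
        3.14159 * radius**2  # no spaces around "**" -> E226
        + scale * radius     # line break before "+" -> W503
    )
```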
.gitignore for Python (file name not shown; the standard GitHub Python template plus project-specific entries at the end):
@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# vscode
/.vscode

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# Others
*.pt
data/
Dockerfile (referenced by name in the Heroku config below):
@@ -0,0 +1,13 @@
FROM python:3.7-slim

COPY setup.py setup.py
COPY requirements.txt requirements.txt
COPY Makefile Makefile
COPY src src
COPY streamlit streamlit

# NOTE: the slim base image does not ship make; it may need to be installed
# (e.g. via apt-get) for this step to succeed.
RUN make install

EXPOSE 8080

# "run" is required by the streamlit CLI to launch an app.
CMD ["streamlit", "run", "streamlit/st_app.py"]
Makefile (copied and invoked by the Dockerfile above):
@@ -0,0 +1,30 @@
# NOTE: "source" inside a recipe only affects that recipe's shell; activate the
# virtualenv manually in your own shell before running the other targets.
env:
	virtualenv ~/.venv &&\
	source ~/.venv/bin/activate

install:
	pip install --upgrade pip &&\
	pip install -r requirements.txt &&\
	pip install -e .

style:
	black .
	flake8 .

test:
	pytest

deploy:
	git push heroku master

.PHONY: streamlit
streamlit:
	streamlit run streamlit/st_app.py

all: env install style test
Heroku container build config (presumably heroku.yml):
@@ -0,0 +1,5 @@
build:
  docker:
    web: Dockerfile
run:
  # NOTE: this is the sample Ruby/puma command from Heroku's heroku.yml docs;
  # for this Python app it would need to launch Streamlit instead (or the run
  # section could be dropped so the Dockerfile's CMD is used).
  web: bundle exec puma -C config/puma.rb
pytest and black configuration (presumably pyproject.toml, where these [tool.*] tables live):
@@ -0,0 +1,28 @@
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
filterwarnings = [
    "error",
    "ignore::DeprecationWarning:",
    # note the use of single quote below to denote "raw" strings in TOML
    'ignore:Using or importing the ABCs:DeprecationWarning',
]

[tool.black]
line-length = 100
include = '\.pyi?$'
exclude = '''
/(
    \.eggs  # exclude a few common directories in the root of the project
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''
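The pytest table points at a tests/ directory with test_*.py files, none of which appear in the diffs that loaded on this page. Purely as a hedged sketch of what a minimal test under this layout could look like (the file name and contents below are assumptions, not part of the commit):

```python
# tests/test_net.py -- hypothetical; assumes src.net.Net is the usual
# 10-class MNIST classifier used by the training script further down.
import torch

from src.net import Net


def test_forward_shape():
    model = Net()
    model.eval()
    with torch.no_grad():
        out = model(torch.zeros(2, 1, 28, 28))  # dummy batch of two 28x28 images
    assert out.shape == (2, 10)
```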
requirements.txt (installed by the Makefile's install target):
@@ -0,0 +1,8 @@
torch==1.6.0
torchvision==0.7.0
black==20.8b1
flake8==3.8.4
pytest==6.2.5
streamlit==0.75.0
streamlit_drawable_canvas==0.8.0
pillow==7.2.0
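The streamlit, streamlit_drawable_canvas, and pillow pins belong to the streamlit/st_app.py front end that the Dockerfile and Makefile launch; that file did not load on this page. The following is only a guess at its general shape, with every detail an assumption: a canvas the user draws a digit on, resized to 28x28 and classified by the trained network.

```python
# streamlit/st_app.py -- hypothetical sketch, not the file from this commit.
# Assumes a Net checkpoint saved by the training script as mnist_cnn.pt.
import numpy as np
import streamlit as st
import torch
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from src.net import Net


@st.cache(allow_output_mutation=True)
def load_model(path="mnist_cnn.pt"):
    model = Net()
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model


st.title("MNIST digit recognizer")
canvas = st_canvas(
    stroke_width=12,
    stroke_color="#FFFFFF",      # white digit on black background, like MNIST
    background_color="#000000",
    height=280,
    width=280,
    drawing_mode="freedraw",
    key="canvas",
)

if canvas.image_data is not None:
    # Downscale the RGBA canvas to a 28x28 grayscale image and rescale to [0, 1].
    img = Image.fromarray(canvas.image_data.astype(np.uint8)).convert("L").resize((28, 28))
    x = torch.from_numpy(np.array(img, dtype=np.float32) / 255.0).view(1, 1, 28, 28)
    with torch.no_grad():
        pred = load_model()(x).argmax(dim=1).item()
    st.write(f"Prediction: {pred}")
```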
setup.py (installed in editable mode by make install):
@@ -0,0 +1,3 @@
from setuptools import setup

setup(name="mnist", packages=["src"])
PyTorch MNIST training and evaluation script (path not shown; presumably a module under src/):
@@ -0,0 +1,145 @@
from __future__ import print_function
import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from src.net import Net


def get_loader(batch_size, train_mode, download=True, **kwargs):
    transform = transforms.Compose(
        [
            transforms.RandomAffine(degrees=(-5, 5), translate=(0.1, 0.3), scale=(0.75, 1.2)),
            transforms.ToTensor(),
        ]
    )
    dataset = datasets.MNIST("./data", train=train_mode, download=download, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, **kwargs)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
            if args.dry_run:
                break


def tst(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--dry-run", action="store_true", default=False, help="quickly check a single pass"
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For saving the current Model"
    )
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # DataLoader extras; only meaningful with CUDA.  Defining the dict
    # unconditionally avoids a NameError on CPU-only runs.
    cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} if use_cuda else {}

    train_loader = get_loader(batch_size=args.batch_size, train_mode=True, **cuda_kwargs)
    test_loader = get_loader(batch_size=args.test_batch_size, train_mode=False, **cuda_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    max_epochs = 2 if args.dry_run else args.epochs + 1
    for epoch in range(1, max_epochs):
        train(args, model, device, train_loader, optimizer, epoch)
        tst(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == "__main__":
    main()
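The script imports Net from src.net, which is among the commit's 13 files but not among the diffs that loaded here. For orientation only, a hedged sketch of the network it most likely resembles (the CNN from the upstream PyTorch MNIST example; the real src/net.py may differ):

```python
# Hypothetical src/net.py -- the standard PyTorch MNIST example CNN, shown
# here only because the real file is not visible in the loaded diff.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.dropout2(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # matches F.nll_loss in the training loop
```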
(The remaining changed files in this commit, including src/net.py and the streamlit front end, did not load on this page.)