forked from rosinality/vq-vae-2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: pixelsnail_mnist.py
executable file
60 lines (40 loc) · 1.48 KB
/
pixelsnail_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
from pixelsnail import PixelSNAIL
def train(epoch, loader, model, optimizer, device):
    """Run one training epoch of the pixel model over ``loader``.

    Each image batch is used both as the model input and as the per-pixel
    classification target of ``nn.CrossEntropyLoss``. Progress (loss and
    per-pixel argmax accuracy of the current batch) is shown in a tqdm bar.

    Args:
        epoch: zero-based epoch index; displayed as ``epoch + 1``.
        loader: iterable of ``(img, label)`` batches; ``label`` is ignored.
        model: network being trained; called as ``model(img)``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the image batches are moved to.
    """
    # Make sure layers such as dropout are in training mode even if the
    # caller previously put the model into eval mode.
    model.train()

    progress = tqdm(loader)
    criterion = nn.CrossEntropyLoss()

    for img, _label in progress:
        model.zero_grad()
        img = img.to(device)

        out = model(img)
        # NOTE(review): upstream PixelSNAIL.forward returns ``(logits, cache)``;
        # keep only the logits when a tuple comes back. Confirm against pixelsnail.py.
        if isinstance(out, tuple):
            out = out[0]

        # assumes out is (batch, n_class, H, W) logits and img is (batch, H, W)
        # integer class targets -- TODO confirm
        loss = criterion(out, img)
        loss.backward()
        optimizer.step()

        # Fraction of pixels whose argmax prediction matches the target,
        # computed for display only.
        pred = out.argmax(1)
        accuracy = (pred == img).float().mean().item()

        progress.set_description(
            f'epoch: {epoch + 1}; loss: {loss.item():.5f}; acc: {accuracy:.5f}'
        )
class PixelTransform:
    """Dataset transform that turns an image into an integer pixel tensor.

    The input is converted with ``np.array`` and wrapped as a ``torch.long``
    tensor. Pixel values are kept as-is (no scaling or normalisation), so
    they can serve directly as class indices for a cross-entropy target.
    Presumably the input is an 8-bit grayscale PIL image from
    ``datasets.MNIST``; verify if reused elsewhere.
    """

    def __init__(self):
        # The transform holds no configuration or state.
        pass

    def __call__(self, input):
        """Return ``input`` as a LongTensor of its raw pixel values."""
        pixels = np.array(input)
        as_tensor = torch.from_numpy(pixels)
        return as_tensor.long()
if __name__ == '__main__':
    import os

    # Use the GPU when present; fall back to CPU instead of crashing in .to('cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epoch = 10

    # Pixels stay integer class indices (see PixelTransform).
    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    # NOTE(review): positional PixelSNAIL hyper-parameters; 256 presumably is the
    # number of pixel classes (0..255) -- confirm against pixelsnail.py.
    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # torch.save does not create missing parent directories.
    os.makedirs('checkpoint', exist_ok=True)

    # Honour the configured epoch count rather than a duplicated literal.
    for i in range(epoch):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(), f'checkpoint/mnist_{i + 1:03d}.pt')