"""CIFAR10 example for cnn_finetune.
Based on:
- https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py
- https://github.com/pytorch/examples/blob/master/mnist/main.py
"""
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from cnn_finetune import make_model

parser = argparse.ArgumentParser(description='cnn_finetune CIFAR-10 example')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-name', type=str, default='resnet50', metavar='M',
                    help='model name (default: resnet50)')
parser.add_argument('--dropout-p', type=float, default=0.2, metavar='D',
                    help='dropout probability (default: 0.2)')
args = parser.parse_args()

torch.manual_seed(args.seed)  # make runs reproducible
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


def train(model, epoch, optimizer, train_loader, criterion=nn.CrossEntropyLoss()):
    model.train()
    total_loss = 0
    total_size = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # criterion returns a per-batch mean, so weight it by the batch size
        # to keep a true per-sample running average
        total_loss += loss.item() * data.size(0)
        total_size += data.size(0)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), total_loss / total_size))


def test(model, test_loader, criterion=nn.CrossEntropyLoss()):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # weight the per-batch mean loss by the batch size so dividing by
            # the dataset size below yields a per-sample average
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    """Main function to run code in this script."""
    model_name = args.model_name
    if model_name == 'alexnet':
        raise ValueError(
            'The input size of the CIFAR-10 data set (32x32) is too small for AlexNet'
        )

    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    )

    model = make_model(
        model_name,
        pretrained=True,
        num_classes=len(classes),
        dropout_p=args.dropout_p,
        # VGG and SqueezeNet expect a fixed input size, so it must be passed
        # explicitly for the 32x32 CIFAR-10 images
        input_size=(32, 32) if model_name.startswith(('vgg', 'squeezenet')) else None,
    )
    model = model.to(device)

    # Normalize with the statistics the pretrained model was trained with
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=model.original_model_info.mean,
            std=model.original_model_info.std),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=2
    )
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=2
    )

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    # Exponential decay for fine-tuning: multiply the learning rate
    # by 0.975 after every epoch
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.975)

    for epoch in range(1, args.epochs + 1):
        train(model, epoch, optimizer, train_loader)
        test(model, test_loader)
        # Step the scheduler once per epoch, after the optimizer updates
        scheduler.step()


if __name__ == '__main__':
    main()