-
Notifications
You must be signed in to change notification settings - Fork 0
/
flux_cnn.jl
71 lines (52 loc) · 1.83 KB
/
flux_cnn.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# A very simple convolutional neural network built with Julia's Flux
# that predicts handwritten digits from the MNIST dataset
using Flux, Plots, ProgressMeter, MLDatasets;
using Flux: train!;
using Flux: onehotbatch;
using Flux: MaxPool;
using Colors;
using BSON: @save
gr();
# Convert one MNIST split into the layout the network expects:
#   - images: add a singleton channel dimension (28×28×N → 28×28×1×N), the
#     width × height × channels × batch order that Flux's Conv layers consume;
#   - labels: one-hot encode the digits 0-9 into a 10×N matrix.
function prepare_mnist(images, labels)
    # reshape shares memory with `images`; it is equivalent to inserting dim 3
    return reshape(images, size(images, 1), size(images, 2), 1, :),
           onehotbatch(labels, 0:9)
end

# Load both full splits as Float32 so inputs match the Float32 layer weights
# (some MLDatasets versions otherwise default to a fixed-point element type).
train_x, train_y = prepare_mnist(MNIST.traindata(Float32)...)
test_x, test_y = prepare_mnist(MNIST.testdata(Float32)...)
# to show image
# using Colors
# plot(Gray.(transpose(train_x[:, :, 1])))
# Small CNN classifier: two (5×5 conv + ReLU, 2×2 max-pool) stages, then a
# dense layer producing 10 class scores turned into probabilities by softmax.
# Shapes for one 28×28 grayscale image (unpadded convs, so each shrinks by 4):
network = Chain(
    Conv((5, 5), 1 => 16, relu),   # 28×28×1  → 24×24×16
    MaxPool((2, 2)),               # 24×24×16 → 12×12×16
    Conv((5, 5), 16 => 32, relu),  # 12×12×16 → 8×8×32
    MaxPool((2, 2)),               # 8×8×32   → 4×4×32
    flatten,                       # 4×4×32   → 512-vector
    Dense(4 * 4 * 32, 10),         # 512      → 10 digit scores
    softmax,                       # scores   → class probabilities
)
parameters = Flux.params(network);
# Named loss function (it closes over the global `network`): cross-entropy
# between the softmax output and the one-hot targets. train! calls it as
# loss(x, y) for every (x, y) mini-batch yielded by the DataLoader.
loss(x, y) = Flux.crossentropy(network(x), y);
opt = ADAM(0.001); # Adam optimizer with learning rate 0.001
epochs = 10;

# Mean `loss` over an entire dataset, computed in mini-batches so the
# activations for all samples never have to be held in memory at once.
function dataset_loss(x, y; batchsize=1_000)
    total = 0.0
    for (xb, yb) in Flux.DataLoader((x, y); batchsize=batchsize)
        # crossentropy averages over its batch by default; weight each batch by
        # its size so the result equals the mean over the whole dataset.
        total += loss(xb, yb) * size(xb, ndims(xb))
    end
    return total / size(x, ndims(x))
end

# One row per epoch: column 1 = training loss, column 2 = validation loss.
# Preallocated and filled in place: rebinding a global inside a top-level `for`
# loop is ambiguous soft scope when run as a file and creates an undefined local.
loss_history = zeros(epochs, 2)
train_data = Flux.DataLoader((train_x, train_y); batchsize=128, shuffle=true)
@showprogress for epoch in 1:epochs
    train!(loss, parameters, train_data, opt)
    loss_history[epoch, 1] = dataset_loss(train_x, train_y)
    loss_history[epoch, 2] = dataset_loss(test_x, test_y)
end
# Plot the per-epoch training and validation loss. `display` is needed for the
# figure to be shown when this file is run as a script rather than in a REPL.
display(plot(loss_history, labels=["train" "validation"],
             xlabel="epoch", ylabel="cross-entropy loss"))
# Serialize the trained `network` to mnist_cnn.bson with BSON.
@save "mnist_cnn.bson" network
# Show a randomly chosen test image annotated with its true and predicted digit.
random_i = rand(axes(test_x, 4))  # random test-sample index (any test-set size)
# Images are stored transposed relative to screen orientation, hence `transpose`.
plot(Gray.(transpose(test_x[:, :, 1, random_i])))
actual_value = Flux.onecold(test_y[:, random_i], 0:9)
# Indexing with the range random_i:random_i keeps the batch dimension, giving a
# single-sample 4-D input without hard-coding the image size.
predicted_value = only(Flux.onecold(network(test_x[:, :, :, random_i:random_i]), 0:9))
display(annotate!((1, 0), text(
    "actual value: $actual_value \npredicted value: $predicted_value ",
    :top, :right, :red)))