-
Notifications
You must be signed in to change notification settings - Fork 1
/
cnn_mnist_neglu.m
100 lines (82 loc) · 3.25 KB
/
cnn_mnist_neglu.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
function [net, info] = cnn_mnist_neglu(varargin)
% CNN_MNIST_NEGLU  Train a CNN on MNIST with MatConvNet.
%   [NET, INFO] = CNN_MNIST_NEGLU(...) sets up MatConvNet, loads or builds
%   the MNIST image database, creates the network with
%   CNN_MNIST_NEGLU_INIT and trains it with CNN_TRAIN. NET and INFO are
%   the values returned by CNN_TRAIN.
%
%   Options are passed as name/value pairs and parsed with VL_ARGPARSE.
%   The defaults set below are:
%     expDir    - data/mnistmixlu025 (experiment output directory)
%     dataDir   - data/mnist (raw MNIST files)
%     imdbPath  - <expDir>/imdb.mat (cached image database)
%     useBnorm  - false (forwarded to CNN_MNIST_NEGLU_INIT)
%     train     - struct forwarded to CNN_TRAIN (batchSize 100,
%                 numEpochs 20, continue true, gpus [],
%                 learningRate 0.001, expDir <expDir>)

run(fullfile(fileparts(mfilename('fullpath')), ...
  '..', 'matlab', 'vl_setupnn.m')) ;

% First pass: only 'expDir' is known at this point, so that the defaults
% derived from it below (imdbPath, train.expDir) follow a user override.
% NOTE(review): relies on the two-output form of vl_argparse handing back
% the arguments it did not consume - confirm against MatConvNet docs.
opts.expDir = fullfile('data','mnistmixlu025') ;
[opts, varargin] = vl_argparse(opts, varargin) ;

opts.dataDir = fullfile('data','mnist') ;
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');
opts.useBnorm = false ;
opts.train.batchSize = 100 ;
opts.train.numEpochs = 20 ;
opts.train.continue = true ;
opts.train.gpus = [] ;
opts.train.learningRate = 0.001 ;
opts.train.expDir = opts.expDir ;
% Second pass: parse the remaining arguments against the full option set.
opts = vl_argparse(opts, varargin) ;

% --------------------------------------------------------------------
% Prepare data
% --------------------------------------------------------------------

if exist(opts.imdbPath, 'file')
  % Reuse the cached database from a previous run.
  imdb = load(opts.imdbPath) ;
else
  imdb = getMnistImdb(opts) ;
  % Only create the experiment directory when it is missing; mkdir warns
  % if the directory already exists.
  if ~exist(opts.expDir, 'dir')
    mkdir(opts.expDir) ;
  end
  save(opts.imdbPath, '-struct', 'imdb') ;
end

net = cnn_mnist_neglu_init('useBnorm', opts.useBnorm) ;

% --------------------------------------------------------------------
% Train
% --------------------------------------------------------------------

% Samples whose set id is 3 (the t10k files) are used for validation.
[net, info] = cnn_train(net, imdb, @getBatch, ...
  opts.train, ...
  'val', find(imdb.images.set == 3)) ;
% --------------------------------------------------------------------
function [im, labels] = getBatch(imdb, batch)
% --------------------------------------------------------------------
% GETBATCH  Select the samples listed in BATCH from the image database.
%   IM is imdb.images.data indexed along its fourth dimension by BATCH;
%   LABELS is the first row of imdb.images.labels at the same indices.
images = imdb.images ;
labels = images.labels(1, batch) ;
im = images.data(:, :, :, batch) ;
% --------------------------------------------------------------------
function imdb = getMnistImdb(opts)
% --------------------------------------------------------------------
% GETMNISTIMDB  Prepare the imdb structure for MNIST.
%   IMDB = GETMNISTIMDB(OPTS) downloads the four MNIST files into
%   OPTS.dataDir when they are missing, reads them, and returns a struct
%   with:
%     images.data      - 28x28x1xN single array, with the mean image of
%                        the training samples subtracted
%     images.data_mean - the subtracted mean image
%     images.labels    - 1xN labels, shifted by one so digit 0 is label 1
%     images.set       - 1xN set ids: 1 for the train files, 3 for t10k
%     meta.sets        - {'train', 'val', 'test'}
%     meta.classes     - {'0', ..., '9'}
%   An error is raised if a file cannot be opened or if the number of
%   images and labels in a split disagree.

files = {'train-images-idx3-ubyte', ...
         'train-labels-idx1-ubyte', ...
         't10k-images-idx3-ubyte', ...
         't10k-labels-idx1-ubyte'} ;

if ~exist(opts.dataDir, 'dir')
  mkdir(opts.dataDir) ;
end

% Fetch and decompress any file that is not already on disk.
for i = 1:numel(files)
  if ~exist(fullfile(opts.dataDir, files{i}), 'file')
    url = sprintf('http://yann.lecun.com/exdb/mnist/%s.gz',files{i}) ;
    fprintf('downloading %s\n', url) ;
    gunzip(url, opts.dataDir) ;
  end
end

% Image files: skip the first 16 bytes, the rest is 28x28 pixel blocks.
% The permute swaps the first two dimensions of every image.
x1 = readMnistPayload(fullfile(opts.dataDir, files{1}), 16) ;
x1 = permute(reshape(x1, 28, 28, []), [2 1 3]) ;
x2 = readMnistPayload(fullfile(opts.dataDir, files{3}), 16) ;
x2 = permute(reshape(x2, 28, 28, []), [2 1 3]) ;

% Label files: skip the first 8 bytes; +1 turns digits 0..9 into 1..10.
y1 = double(readMnistPayload(fullfile(opts.dataDir, files{2}), 8)') + 1 ;
y2 = double(readMnistPayload(fullfile(opts.dataDir, files{4}), 8)') + 1 ;

% The sample count is taken from the files, so check both halves agree.
if size(x1, 3) ~= numel(y1) || size(x2, 3) ~= numel(y2)
  error('cnn_mnist_neglu:sizeMismatch', ...
        'MNIST image and label counts differ (%d/%d train, %d/%d test).', ...
        size(x1, 3), numel(y1), size(x2, 3), numel(y2)) ;
end

% Named setIds rather than "set" to avoid shadowing the MATLAB builtin.
setIds = [ones(1,numel(y1)) 3*ones(1,numel(y2))];
data = single(reshape(cat(3, x1, x2),28,28,1,[]));
% The mean image is computed from the training samples only.
dataMean = mean(data(:,:,:,setIds == 1), 4);
data = bsxfun(@minus, data, dataMean) ;

imdb.images.data = data ;
imdb.images.data_mean = dataMean;
imdb.images.labels = cat(2, y1, y2) ;
imdb.images.set = setIds ;
imdb.meta.sets = {'train', 'val', 'test'} ;
imdb.meta.classes = arrayfun(@(x)sprintf('%d',x),0:9,'uniformoutput',false) ;

% --------------------------------------------------------------------
function payload = readMnistPayload(path, headerLength)
% --------------------------------------------------------------------
% Read a whole file as uint8 values (returned as a double column vector)
% and drop its first HEADERLENGTH bytes.
fid = fopen(path, 'r') ;
if fid < 0
  error('cnn_mnist_neglu:fileOpen', 'Could not open MNIST file %s.', path) ;
end
% Close the file even if fread throws.
closer = onCleanup(@() fclose(fid)) ;
payload = fread(fid, inf, 'uint8') ;
payload = payload(headerLength+1:end) ;