-
Notifications
You must be signed in to change notification settings - Fork 1
/
cnn_cifar.m
135 lines (114 loc) · 4.49 KB
/
cnn_cifar.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
function [net, info] = cnn_cifar(varargin)
% CNN_CIFAR  Demonstrates MatConvNet on CIFAR-10
%   [NET, INFO] = CNN_CIFAR('option', value, ...) builds the CIFAR-10
%   image database, initialises the selected model and trains it with
%   CNN_TRAIN. NET and INFO are the two outputs of CNN_TRAIN.
%
%   The demo includes two standard models: LeNet and Network in
%   Network (NIN). Use the 'modelType' option ('lenet' or 'nin') to
%   choose one; any other value raises an error.
%
%   Options are parsed in several stages because later defaults depend
%   on earlier choices (the model type selects the learning-rate
%   schedule, which in turn sets the default number of epochs, and the
%   experiment directory determines the default imdb path). The final
%   parse rejects any option that is still unrecognised.

run(fullfile(fileparts(mfilename('fullpath')), ...
  '..', 'matlab', 'vl_setupnn.m')) ;

% Stage 1: the model type determines the default training schedule.
opts.modelType = 'lenet' ;
[opts, varargin] = vl_argparse(opts, varargin) ;

switch opts.modelType
  case 'lenet'
    opts.train.learningRate = [0.05*ones(1,15) 0.005*ones(1,10) 0.0005*ones(1,5)] ;
    opts.train.weightDecay = 0.0001 ;
  case 'nin'
    opts.train.learningRate = [0.5*ones(1,30) 0.1*ones(1,10) 0.02*ones(1,10)] ;
    opts.train.weightDecay = 0.0005 ;
  otherwise
    error('Unknown model type %s', opts.modelType) ;
end

% Stage 2: the experiment directory (overridable by the caller).
opts.expDir = fullfile('data', sprintf('cifar4real-%s', opts.modelType)) ;
[opts, varargin] = vl_argparse(opts, varargin) ;

% Stage 3: by default train for one epoch per learning-rate entry.
opts.train.numEpochs = numel(opts.train.learningRate) ;
[opts, varargin] = vl_argparse(opts, varargin) ;

% Stage 4: remaining defaults, some derived from opts.expDir.
opts.dataDir = fullfile('data','cifar') ;
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');
opts.whitenData = true ;
opts.contrastNormalization = true ;
opts.train.batchSize = 100 ;
opts.train.continue = true ;
opts.train.gpus = [] ;
% NOTE(review): opts.numEpochs is not read anywhere in this file (the
% trainer receives opts.train.numEpochs); it is kept so that callers
% passing 'numEpochs' are still accepted.
opts.numEpochs = 50 ;
opts.train.expDir = opts.expDir ;
opts = vl_argparse(opts, varargin) ;

% --------------------------------------------------------------------
%                                               Prepare data and model
% --------------------------------------------------------------------

switch opts.modelType
  case 'lenet', net = cnn_cifar_init(opts) ;
  case 'nin',   net = cnn_cifar_init_nin(opts) ;
end

% The imdb is rebuilt on every call and the copy stored at
% opts.imdbPath is overwritten; an existing file is never reloaded.
imdb = getCifarImdb(opts) ;
if ~exist(opts.expDir, 'dir')
  % Only create the directory when needed: mkdir warns if it exists.
  mkdir(opts.expDir) ;
end
save(opts.imdbPath, '-struct', 'imdb') ;

% --------------------------------------------------------------------
%                                                                Train
% --------------------------------------------------------------------

% Images whose set id is 3 are handed to the trainer as validation data.
[net, info] = cnn_train(net, imdb, @getBatch, ...
  opts.train, ...
  'val', find(imdb.images.set == 3)) ;
% --------------------------------------------------------------------
function [im, labels] = getBatch(imdb, batch)
% --------------------------------------------------------------------
% GETBATCH  Fetch one mini-batch of images and labels from the imdb.
%   BATCH indexes the 4th dimension of imdb.images.data and the columns
%   of the first row of imdb.images.labels. With probability of about
%   one half the whole batch is mirrored along its 2nd dimension
%   (random flip augmentation); the decision is made once per batch.
im = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ;
if rand > 0.5
  % Reverse dimension 2 by explicit indexing. This gives the same
  % result as fliplr but does not depend on fliplr accepting N-D
  % input, which older MATLAB releases reject.
  im = im(:, end:-1:1, :, :) ;
end
% --------------------------------------------------------------------
function imdb = getCifarImdb(opts)
% --------------------------------------------------------------------
% GETCIFARIMDB  Prepare the imdb structure for CIFAR-10.
%   IMDB = GETCIFARIMDB(OPTS) loads the five training batches and the
%   test batch from OPTS.dataDir/cifar-10-batches-mat (downloading and
%   unpacking the archive first if any of them is missing), subtracts
%   the mean training image, and optionally applies per-image contrast
%   normalization (OPTS.contrastNormalization) and whitening
%   (OPTS.whitenData).
%
%   The returned structure has the fields
%     imdb.images.data    32x32x3xN single array of processed images
%     imdb.images.labels  1xN single class labels, indexed from 1
%     imdb.images.set     1xN uint8 set ids (1 for the training
%                         batches, 3 for the test batch)
%     imdb.meta.sets      {'train', 'val', 'test'}
%     imdb.meta.classes   class names read from batches.meta.mat
unpackPath = fullfile(opts.dataDir, 'cifar-10-batches-mat');
files = [arrayfun(@(n) sprintf('data_batch_%d.mat', n), 1:5, 'UniformOutput', false) ...
  {'test_batch.mat'}];
files = cellfun(@(fn) fullfile(unpackPath, fn), files, 'UniformOutput', false);
file_set = uint8([ones(1, 5), 3]);

if any(cellfun(@(fn) ~exist(fn, 'file'), files))
  url = 'http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz' ;
  fprintf('downloading %s\n', url) ;
  untar(url, opts.dataDir) ;
  % Fail with a clear message rather than an opaque load error below.
  missing = files(cellfun(@(fn) ~exist(fn, 'file'), files)) ;
  if ~isempty(missing)
    error('CIFAR-10 file %s not found after download', missing{1}) ;
  end
end

data = cell(1, numel(files));
labels = cell(1, numel(files));
sets = cell(1, numel(files));
for fi = 1:numel(files)
  fd = load(files{fi}) ;
  % Each row of fd.data is one flattened image; reshape to 32x32x3xN
  % and swap the first two dimensions.
  data{fi} = permute(reshape(fd.data',32,32,3,[]),[2 1 3 4]) ;
  labels{fi} = fd.labels' + 1; % Index from 1
  sets{fi} = repmat(file_set(fi), size(labels{fi}));
end

% Named setIds (not "set") to avoid shadowing the MATLAB builtin.
setIds = cat(2, sets{:});
data = single(cat(4, data{:}));
% Number of images actually loaded; replaces a hard-coded 60000 so the
% reshapes below do not break if the batch files hold a different count.
numImages = size(data, 4) ;

% remove mean in any case (the mean is computed on training images only)
dataMean = mean(data(:,:,:,setIds == 1), 4);
data = bsxfun(@minus, data, dataMean);

% normalize by image mean and std as suggested in `An Analysis of
% Single-Layer Networks in Unsupervised Feature Learning` Adam
% Coates, Honglak Lee, Andrew Y. Ng
if opts.contrastNormalization
  z = reshape(data,[],numImages) ;
  z = bsxfun(@minus, z, mean(z,1)) ;
  n = std(z,0,1) ;
  % The floor of 40 on the per-image std avoids amplifying images
  % with very low contrast.
  z = bsxfun(@times, z, mean(n) ./ max(n, 40)) ;
  data = reshape(z, 32, 32, 3, []) ;
end

if opts.whitenData
  z = reshape(data,[],numImages) ;
  % Covariance estimated from training columns only. The divisor is
  % the total image count (as in the original code, where it was the
  % literal 60000), not the number of training images; it is kept so
  % the numerical result is unchanged.
  W = z(:,setIds == 1)*z(:,setIds == 1)'/numImages ;
  [V,D] = eig(W) ;
  % the scale is selected to approximately preserve the norm of W
  d2 = diag(D) ;
  en = sqrt(mean(d2)) ;
  % The floor of 10 limits the gain applied to low-variance directions.
  z = V*diag(en./max(sqrt(d2), 10))*V'*z ;
  data = reshape(z, 32, 32, 3, []) ;
end

clNames = load(fullfile(unpackPath, 'batches.meta.mat'));

imdb.images.data = data ;
imdb.images.labels = single(cat(2, labels{:})) ;
imdb.images.set = setIds;
imdb.meta.sets = {'train', 'val', 'test'} ;
imdb.meta.classes = clNames.label_names;