forked from mwgeurts/libra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcda.m
executable file
·487 lines (462 loc) · 20.7 KB
/
cda.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
function result=cda(x,group,varargin)
%CDA performs linear and quadratic classical discriminant analysis
% on the data matrix x with known group structure.
%
% Required input arguments:
% x : training data set (matrix of size n by p).
% group : column vector containing the group numbers of the training
% set x. For the group numbers, any strict positive integer is
% allowed assuming that the first group is the one with
% the smallest group number.
%
% Optional input arguments:
% method : String which indicates whether a 'linear' (default) or 'quadratic'
% discriminant rule should be applied
% misclassif : String which indicates how to estimate the probability of
% misclassification. It can be based on the
% training data ('training'), a validation set ('valid'),
% or cross-validation ('cv'). Default is 'training'.
% membershipprob : Vector which contains the membership probability of each
% group (sorted by increasing group number). If no priors are given, they are
% estimated as the proportions of observations in the training set.
% valid : If misclassif was set to 'valid', this field should contain
% the validation set (a matrix of size m by p).
% groupvalid : If misclassif was set to 'valid', this field should contain the group numbers
% of the validation set (a column vector).
% predictset : Contains a new data set (a matrix of size mp by p) from which the
% class memberships are unknown and should be predicted.
% plots : If equal to 1, one figure is created with the training data and the
% tolerance ellipses for each group. This plot is
% only available for bivariate data sets. For technical reasons, a maximum
% of 6 groups is allowed. Default is one.
%
% Options for advanced users (input comes from the program RSIMCA.m with option 'classic' = 1):
%
% weightstrain : The weights for the training data. Corresponds to the flagtrain from RDA. (default = 1)
% weightsvalid : The weights for the validation data. Corresponds to the flagvalid from RDA. (default = 1)
% I/O: result=cda('plots',0,'misclassif','training','method','linear',...
% 'membershipprob',proportions,'valid',y,'groupvalid',groupy);
% The user should only give the input arguments that have to change their default value.
% The name of the input arguments needs to be followed by their value.
% The order of the input arguments is of no importance.
%
% Examples: out=cda(x,group,'method','linear')
% out=cda(x,group,'plots',0)
% out=cda(x,group,'valid',y,'groupvalid',groupy)
%
% The output is a structure containing the following fields:
% result.assignedgroup : If there is a validation set, this vector contains the assigned group numbers
% for the observations of the validation set. Otherwise it contains the
% assigned group numbers of the original observations based on the discriminant rules.
% result.scores : If there is a validation set, this column of size m contains the maximal discriminant
% scores for each observation from the validation set. Otherwise it is a columnvector of
% size n containing the maximal discriminant scores of the training set.
% result.method : String containing the method used to obtain
% the discriminant rules (either 'linear' or 'quadratic'). This
% is the same as the input argument method.
% result.cov : If method equals 'linear', this is a matrix containing the
% estimated pooled covariance matrix.
% If method equals 'quadratic', it is a cell array containing the covariances per group.
% result.center : A vector in which the rows contain the estimated centers of the groups.
% result.md : A vector of length n containing the mahalanobis distances of each observation from the training set to the
% center of its group.
% result.flagtrain : Observations from the training set whose mahalanobis distance exceeds a certain cut-off value
% can be considered as outliers and receive a flag equal to zero. The regular observations
% receive a flag 1. (See also mcdcov.m)
% result.flagvalid : Observations from the validation set whose mahalanobis distance (to the center of their group)
% exceeds a certain cut-off value can be considered as outliers and receive a
% flag equal to zero. The regular observations receive a flag 1.
% If there is no validation set, this field is equal to zero.
% result.grouppredict : If there is a prediction set, this vector contains the assigned group numbers
% for the observations of the prediction set.
% result.flagpredict : Observations from the new data set (predict) whose robust distance (to the center of their group)
% exceeds a certain cut-off value can be considered as overall outliers and receive a
% flag equal to zero. The regular observations receive a flag 1.
% If there is no prediction set, this field is
% equal to zero.
% result.membershipprob : A vector with the membership probabilities.
% result.misclassif : String containing the method used to estimate the misclassification probabilities
% (same as the input argument misclassif)
% result.groupmisclasprob : A vector containing the misclassification
% probabilities for each group.
% result.avemisclasprob : Overall probability of misclassification (weighted average of the misclassification
% probabilities over all groups).
% result.class : 'CDA'
% result.classic : Is equal to 0 since this analysis is a classical analysis.
% result.x : The training data set (same as the input argument x).
% result.group : The group numbers of the training set (same as the input argument group).
%
% This function is part of LIBRA: the Matlab Library for Robust Analysis,
% available at:
% http://wis.kuleuven.be/stat/robust.html
%
% Written by Nele Smets on 02/02/2004
% Last Update: 01/07/2005
%
if nargin<2
error('There are too few input arguments.')
end
%assigning default-values
[n,p]=size(x);
if size(group,1)~=1
group=group';
end
if n ~= length(group)
error('The number of observations is not the same as the length of the group vector!')
end
g=group;
countsorig=tabulate(g); %contingency table (outputmatrix with 3 colums: value - number - percentage)
[lev,levi,levj]=unique(g);
%Redefining the group number
if any(lev~= (1:length(lev)))
lev=1:length(lev);
g=lev(levj);
counts=tabulate(g);
else
counts=countsorig;
end
if ~all(counts(:,2)) %some groups have zero values, omit those groups
disp(['Warning: group(s) ', num2str(counts(counts(:,2)==0,1)'), 'are empty.']);
empty=counts(counts(:,2)==0,:);
counts=counts(counts(:,2)~=0,:);
else
empty=[];
end
if any(counts(:,2)<5)%some groups have less than 5 observations
error(['Group(s) ', num2str(counts(counts(:,2)<5,1)'), ' have less than 5 observations.']);
end
proportions=zeros(size(counts,1),1);
y=0; %initial values of the validation data set and its groupsvector
groupy=0;
counter=1;
weightstrain = ones(1,n);
weightsvalid = 0;
default=struct('plots',1,'misclassif','training','method','linear','membershipprob',proportions,...
'valid',y,'groupvalid',groupy,'weightstrain',weightstrain,'weightsvalid',weightsvalid,'predictset',[]);
list=fieldnames(default);
options=default;
IN=length(list);
i=1;
%reading the user's input
if nargin>2
%
%placing inputfields in array of strings
%
for j=1:nargin-2
if rem(j,2)~=0
chklist{i}=varargin{j};
i=i+1;
end
end
%
%Checking which default parameters have to be changed
% and keep them in the structure 'options'.
%
while counter<=IN
index=strmatch(list(counter,:),chklist,'exact');
if ~isempty(index) %in case of similarity
for j=1:nargin-2 %searching the index of the accompanying field
if rem(j,2)~=0 %fieldnames are placed on odd index
if strcmp(chklist{index},varargin{j})
I=j;
end
end
end
options=setfield(options,chklist{index},varargin{I+1});
index=[];
end
counter=counter+1;
end
end
%Checking prior (>0 )
prior=options.membershipprob;
if size(prior,1)~=1
prior=prior';
end
if sum(prior) ~= 0
epsilon=10^-4;
if (any(prior < 0) | (abs(sum(prior)-1)) > epsilon)
error('Invalid membership probabilities.')
end
end
ng=length(proportions);
if length(prior)~=ng
error('The number of membership probabilities is not the same as the number of groups.')
end
%%%%%%%%%%%%%%%%%%MAIN FUNCTION %%%%%%%%%%%%%%%%%%%%%
%Checking if a validation set is given
if strmatch(options.misclassif, 'valid','exact')
if options.valid==0
error(['The misclassification error will be estimated through a validation set',...
'but no validation set is given!'])
else
validx = options.valid;
validgrouping = options.groupvalid;
if size(validx,1)~=length(validgrouping)
error('The number of observations in the validation set is not the same as the length of its group vector!')
end
if size(validgrouping,1)~=1
validgrouping = validgrouping';
end
countsvalidorig=tabulate(validgrouping);
countsvalid=countsvalidorig(countsvalidorig(:,2)~=0,:);
if size(countsvalid,1)==1
error('The validation set must contain observations from more than one group!')
elseif any(ismember(empty,countsvalid(:,1)))
error(['Group(s)' , num2str(empty(ismember(empty,countsvalid(:,1)))), 'was/were empty in the original dataset.'])
end
if (length(options.weightsvalid) == 1) | (length(options.weightsvalid)~=size(validx,1))
options.weightsvalid = ones(size(validx,1),1);
end
end
elseif options.valid~=0
validx = options.valid;
validgrouping = options.groupvalid;
if size(validx,1) ~= length(validgrouping)
error('The number of observations in the validation set is not the same as the length of its group vector!')
end
if size(validgrouping,1)~=1
validgrouping = validgrouping';
end
options.misclassif='valid';
countsvalidorig=tabulate(validgrouping);
countsvalid=countsvalidorig(countsvalidorig(:,2)~=0,:);
if size(countsvalid,1)==1
error('The validation set must contain more than one group!')
elseif any(ismember(empty,countsvalid(:,1)))
error(['Group(s)' , num2str(empty(ismember(empty,countsvalid(:,1)))), 'was/were empty in the original dataset.'])
end
if (length(options.weightsvalid) == 1) | (length(options.weightsvalid)~=size(validx,1))
options.weightsvalid = ones(size(validx,1),1);
end
end
%Discriminant rule based on the training set x
result1 = rawrule(x, g,prior, options.method);
%Discriminant rule based on reweighted results
if strmatch(options.misclassif,'valid','exact')
result2 = rewrule(validx, result1);
finalgroup = result2.class;
else
result2 = rewrule(x, result1);
finalgroup = result2.class;
end
%Apply discriminant rule on validation set
switch options.misclassif
case 'valid'
[v,vi,vj]=unique(validgrouping);
%Redefining the group number
if any(v~= (1:length(v)))
v=1:length(v);
validgrouping=v(vj);
elseif size(validgrouping,1)~=1
validgrouping=validgrouping';
end
if any(countsvalidorig(:,2)==0)
empty=setdiff(countsvalidorig(find(countsvalidorig(:,2)==0),1), countsorig(find(countsorig(:,2)==0)));
disp(['Warning: the test group(s) ' , num2str(empty'), ' are empty']);
else
empty=[];
end
misclas=-ones(1,length(lev));
for i=1:size(validx,1)
if strmatch(options.method,'quadratic','exact')
dist(i) = mahalanobis(validx(i,:), result1.center(vj(i),:),'invcov',result1.invcov{vj(i)});
else
dist(i) = mahalanobis(validx(i, :), result1.center(vj(i),:),'invcov',result1.invcov);
end
end
weightsvalid=zeros(1,length(dist));
weightsvalid(dist <= chi2inv(0.975, p))=1;
for i=1:length(v)
if ~isempty(intersect(i,v))
misclas(i)=sum((validgrouping(options.weightsvalid == 1)==finalgroup(options.weightsvalid == 1)')...
& (validgrouping(options.weightsvalid == 1)==repmat(lev(i),1,sum(options.weightsvalid))));
ingroup(i) = sum((validgrouping(options.weightsvalid == 1) == repmat(lev(i), 1,sum(options.weightsvalid))));
misclas(i) = (1 - (misclas(i)./ingroup(i)));
end
end
if any(misclas==-1)
misclas(misclas==-1)=0;
end
misclasprobpergroup=misclas;
misclas=misclas.*result1.prior;
misclasprob=sum(misclas);
case 'training'
for i=1:ng
misclas(i) = sum((g(options.weightstrain==1)==finalgroup(options.weightstrain==1)')...
&(g(options.weightstrain==1)==repmat(lev(i),1,sum(options.weightstrain))));
ingroup(i) = sum((g(options.weightstrain==1) == repmat(lev(i),1,sum(options.weightstrain))));
end
misclas = (1 - (misclas./ingroup));
misclasprobpergroup = misclas;
misclas = misclas.*result1.prior;
misclasprob = sum(misclas);
weightsvalid=0;%only available with validation set
case 'cv'
finalgroup=[];
for i=1:length(x)
xnew=removal(x,i,0);
groupnew=removal(group,0,i);
functie1res = rawrule(xnew, groupnew,prior, options.method);
functie2res = rewrule(x(i, :), functie1res);
finalgroup = [finalgroup; functie2res.class(1)];
end
for i=1:ng
misclas(i) = sum((g(options.weightstrain==1) == finalgroup(options.weightstrain==1)')...
& (g(options.weightstrain==1) == repmat(lev(i),1,sum(options.weightstrain))));
ingroup(i) = sum(g(options.weightstrain==1) == repmat(lev(i),1,sum(options.weightstrain)));
end
misclas = (1 - (misclas./ingroup));
misclasprobpergroup= misclas;
misclas = misclas.* result1.prior;
misclasprob = sum(misclas);
weightsvalid=0; %only available with validation set
end
%classify the new observations (predict)
if ~isempty(options.predictset)
resultpredict = rewrule(options.predictset, result1);
finalgrouppredict = resultpredict.class;
for i=1:size(options.predictset,1)
for j = 1:ng
if strmatch(options.method,'quadratic','exact')
distpredict(i,j) = mahalanobis(options.predictset(i,:), result1.center(j,:),'invcov',result1.invcov{j});
else
distpredict(i,j) = mahalanobis(options.predictset(i, :), result1.center(j,:),'invcov',result1.invcov);
end
end
end
weightspredict = zeros(1,size(distpredict,1));
weightspredict(min(distpredict,[],2) <= chi2inv(0.975, p))=1;
else
finalgrouppredict = 0;
weightspredict = 0;
end
%Output structure
result=struct('assignedgroup',{finalgroup'},'scores',{result2.scores'},'method',{result2.method},...
'cov',{result1.cov},'center',{result1.center},'md',{result1.mahal'}, 'flagtrain',{result1.flag'},...
'flagvalid',{weightsvalid},'grouppredict',finalgrouppredict,'flagpredict',weightspredict','membershipprob',{result1.prior},...
'misclassif',{options.misclassif},'groupmisclasprob',{misclasprobpergroup},...
'avemisclasprob',{misclasprob},'class',{'CDA'},'x',{x},'group',{group});
if size(x,2)~=2
result=rmfield(result,{'x','group'});
end
%Plotting the output
try
if options.plots
makeplot(result);
end
catch %output must be given even if plots are interrupted
%> delete(gcf) to get rid of the menu
end
%--------------------------------------------------------------------------
function result=rawrule(x, grouping,prior, method)
%calculate the discrimination rule based on the training set x.
[n,p]=size(x);
g=grouping;
epsilon=10^-4;
[gun,gi,gj]=unique(g);
ng=length(gun);
switch method
case 'linear' %equal covariances supposed
wsum=0;
for j=1:length(gun)
group.cov{j}=cov(x(g==gun(j),:));
wsum=wsum+sum(g==gun(j))*group.cov{j};
group.center(j,:)=mean(x(g==gun(j),:)); %center of all groups, matrix of ng x p
end
covar=wsum/n;
for i=1:n
zgeg(i,:)=x(i,:)-group.center(gj(i),:);
end
dist=zeros(n,1);
for j=1:length(gun)
dist(g==gun(j))=mahalanobis(x(g==gun(j),:),group.center(j,:),'invcov',inv(cov(zgeg)));
end
weights=zeros(n,1);
weights(dist <= chi2inv(0.975,p))=1;
result.cov=covar; %over all group
result.invcov=inv(covar);
result.flag=weights;
result.mahal=dist;
result.center=group.center; %all groups
result.method=method;
case 'quadratic'
for j=1:length(gun)
group.cov{j}=cov(x(g==gun(j),:)); %covariance of group j
group.invcov{j}=inv(group.cov{j});
group.center(j,:)=mean(x(g==gun(j),:)); %center of all groups
end
for i=1:n
xdist(i)=mahalanobis(x(i,:), group.center(gj(i),:), 'invcov',group.invcov{gj(i)});
end
weights=zeros(n,1);
weights(xdist <= chi2inv(0.975,p))=1;
result.cov=group.cov; %per group
result.invcov=group.invcov;
result.center = group.center; %all groups
result.mahal = xdist';
result.flag = weights;
result.method = method;
end
if sum(prior) == 0
counts = tabulate(g);
if ~any(counts(:,2))
disp(['Warning: the group(s) ', num2str(counts(counts(:,2) == 0,1)'), 'contain only outliers']);
counts=counts(counts(:,2)~=0,:);
end
result.prior = (counts(:,3)/100)';
else
result.prior = prior;
end
%--------------------------------------------------------------------------
function result=rewrule(x, rawobject)
epsilon=10^-4;
center=rawobject.center;
covar=rawobject.cov;
invcov=rawobject.invcov;
prior=rawobject.prior;
method=rawobject.method;
if (length(prior) == 0 | length(prior) ~= size(center,1))
error('invalid prior')
end
if sum(prior)~=0
if (any(prior < 0) | (abs(sum(prior)-1)) > epsilon)
error('invalid prior')
end
end
ngroup=length(prior);
[n,p]=size(x);
switch method
case 'linear'
for j=1:ngroup
for i=1:n
scores(i,j) = linclassification(x(i,:)', center(j,:)', invcov, prior(j));
end
end
[maxs,maxsI] = max(scores,[],2);
for i=1:n
maxscore(i,1) = scores(i,maxsI(i));
end
result.scores = maxscore;
result.class = maxsI;
result.method = method;
case 'quadratic'
for j=1:ngroup
for i=1:n
scores(i, j) = classification(x(i,:)', center(j,:)', covar{j}, invcov{j}, prior(j));
end
end
[maxs,maxsI] = max(scores,[],2);
for i=1:n
maxscore(i,1) = scores(i,maxsI(i));
end
result.scores = maxscore;
result.class = maxsI;
result.method = method;
end
%--------------make sure the inputvariables are columnvectors!
function out=classification(x, center, covar,invcov, priorprob)
out=-0.5*log(abs(det(covar)))-0.5*(x - center)' * invcov *(x - center)+log(priorprob);
%-------------------
function out=linclassification(x, center, invcov, priorprob)
out=center'*invcov* x - 0.5*center'*invcov*center+log(priorprob);