-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotNumericalStability.m
96 lines (77 loc) · 2.54 KB
/
plotNumericalStability.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
% Visualize the numerical stability of all solvers by plotting the
% distribution of distance residuals obtained from noiseless data.
% The second plot shows, for each solver, the percentage of trials in
% which it failed to find any real solution.
%% Setup.
% Problem dimensions handed to createRandomUpgradeProblem (its first two
% arguments).
[m, n] = deal(10, 10);
% Number of random trials. Raise it for smoother histograms; the paper
% uses 10000.
iters = 100;
% Noise level of the distance measurements (0 for this noiseless
% experiment).
sigma = 0;
% Solvers under evaluation.
solvers = getSolvers();
%% Solve random problems for all solvers.
nSolvers = length(solvers);
% Smallest RMS distance residual per solver (row) and trial (column);
% entries stay NaN for trials where the solver returned no solution.
ress = nan(nSolvers,iters);
% Per-solver count of trials in which the solver returned no solution.
failures = zeros(nSolvers,1);
% Options handed to every solver.solve call (local-optimization settings).
opts = struct('tol',1e-9,'maxIters',20,'refine',false);
for trial = 1:iters
    fprintf('Iteration: %d/%d\n',trial,iters);
    % Fresh random problem: the low-rank matrices together with the
    % ground-truth measurement matrix.
    prob = createRandomUpgradeProblem(m,n,3,sigma);
    for s = 1:nSolvers
        solver = solvers(s);
        % Minimal sample drawn from the shared problem for this solver.
        sample = createRandomSample(solver.id,prob);
        [Lti,q] = solver.solve(sample,opts);
        if isempty(Lti)
            % Nothing returned: record a failure and go to the next solver.
            failures(s) = failures(s)+1;
            continue;
        end
        % Keep the best residual over all returned candidate solutions.
        % min() ignores the initial NaN, so the first candidate always wins.
        for c = 1:length(Lti)
            Rhat = Lti{c}*sample.fullU;
            Shat = Lti{c}'\(sample.fullV+q(:,c));
            Dhat = pdist2(Rhat',Shat');
            ress(s,trial) = min(ress(s,trial),rms(prob.Dmeas(:)-Dhat(:)));
        end
    end
end
%% Plot residual distributions.
% Histogram bin edges in log10(error); the two presets match the paper.
edges = -17:0.25:1; % Figure 1 in paper.
% edges = -16:0.25:2; % Figure 2 in paper.
centers = edges(1:end-1)+diff(edges)/2;
% One color per solver. Swap the 2nd and 3rd line color for consistency
% with Figure 3. The swap needs at least three solvers; with fewer, the
% original unguarded swap indexed past the end of lineColors.
lineColors = lines(length(solvers));
if size(lineColors,1) >= 3
    lineColors([2 3],:) = lineColors([3 2],:);
end
figure(1);
clf; % Drop any content left in figure 1 by a previous run.
for k=1:length(solvers)
    % Fraction of all trials whose best residual falls into each bin.
    % histcounts ignores NaN (trials with no solution) and values outside
    % the edges, so a curve may sum to less than 1.
    hc = histcounts(log10(ress(k,:)),edges)/iters;
    % lines() repeats the axes color order (7 colors by default), so
    % solvers beyond the 7th are drawn dashed to stay distinguishable.
    if k <= 7
        lineStyle = '-';
    else
        lineStyle = '--';
    end
    plot(centers,hc,lineStyle,'LineWidth',2,'Color',lineColors(k,:));
    hold on
end
hold off
title('Numerical stability for all solvers');
xlabel('log_{10}(error)');
% axis([edges(1) edges(end) 0 0.1]); % Figure 1 in paper.
axis([edges(1) edges(end) ylim]);
legend({solvers.name},'Location','EastOutside');
% set(gca,'FontName','Times');
% set(gca,'FontSize',14);
%% Plot failure rate.
% Percentage of trials in which each solver returned no solution: one bar
% per solver, colored with the same lineColors row as its curve in figure 1.
figure(2);
clf; % Drop any content left in figure 2 by a previous run.
bp = bar(failures/iters*100,'FaceColor','flat');
bp.CData = lineColors(1:length(solvers),:);
ylabel('%');
% xticklabels only relabels the ticks that already exist. With many bars,
% the automatic ticks need not sit at 1:N, which would put solver names
% under the wrong bars. So place one tick under every bar first.
xticks(1:length(solvers));
xticklabels({solvers.name});
title('Percentage of failures');