Skip to content

Commit

Permalink
modify ldsEm to use ldsPca as initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
sth4nth committed Nov 29, 2018
1 parent ca599be commit 62279de
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 29 deletions.
33 changes: 18 additions & 15 deletions chapter13/LDS/ldsEm.m
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
model = init(X,m);
end
tol = 1e-4;
maxIter = 1000;
maxIter = 2000;
llh = -inf(1,maxIter);
for iter = 2:maxIter
% E-step
Expand All @@ -29,20 +29,23 @@
llh = llh(2:iter);

function model = init(X, k)
d = size(X,1);
model.mu0 = randn(k,1);
model.P0 = iwishrnd(eye(k),k);
model.A = randn(k,k);
model.G = iwishrnd(eye(k),k);
model.C = randn(d,k);
model.S = iwishrnd(eye(d),d);
% [A,C,Z] = ldsPca(X,k,3*k);
% model.mu0 = Z(:,1);
% model.P0 = ;
% model.A = A;
% model.C = C;
% model.G = ;
% model.S = ;
% d = size(X,1);
% model.mu0 = randn(k,1);
% model.P0 = iwishrnd(eye(k),k);
% model.A = randn(k,k);
% model.G = iwishrnd(eye(k),k);
% model.C = randn(d,k);
% model.S = iwishrnd(eye(d),d);
[A,C,Z] = ldsPca(X,k,3*k);
model.mu0 = Z(:,1);
E = Z(:,1:end-1)-Z(:,2:end);
model.P0 = (dot(E(:),E(:))/(k*size(E,2)))*eye(k);
model.A = A;
E = A*Z(:,1:end-1)-Z(:,2:end);
model.G = E*E'/size(E,2);
model.C = C;
E = C*Z-X(:,1:size(Z,2));
model.S = E*E'/size(E,2);

function model = maximization(X ,nu, U, Ezz, Ezy)
n = size(X,2);
Expand Down
29 changes: 15 additions & 14 deletions demo/ch13/lds_demo.m
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
close all;
%% Parameter
% Parameter
clear;
d = 2;
k = 2;
n = 50;
k = 3;
n = 100;

A = [1,1;
0 1];
A = [1,0,1;
0 1,0;
0,0,1];
G = eye(k)*1e-3;

C = [1 0;
0 1];
C = [1,0,0;
0 1,0];
S = eye(d)*1e-1;

mu0 = [0; 0];
mu0 = [0;0;0];
P0 = eye(k);

model.A = A;
Expand Down Expand Up @@ -54,9 +55,9 @@
axis equal
hold off
%% LDS Subspace
[A,C,z] = ldsPca(x,k,3*k);
y = C*z;
t = size(z,2);
[A,C,nu] = ldsPca(x,k,3*k);
y = C*nu;
t = size(y,2);
figure;
hold on
plot(x(1,1:t), x(2,1:t), 'ro');
Expand All @@ -66,9 +67,9 @@
axis equal
hold off
%% LDS EM
[model, llh] = ldsEm(x,k);
nu = kalmanSmoother(model,x);
y = model.C*nu;
[tmodel, llh] = ldsEm(x,k);
nu = kalmanSmoother(tmodel,x);
y = tmodel.C*nu;
figure
hold on
plot(x(1,:), x(2,:), 'ro');
Expand Down

0 comments on commit 62279de

Please sign in to comment.