From 62279de8275618a79ccabec77cd90a484021bd61 Mon Sep 17 00:00:00 2001 From: Mo Chen Date: Fri, 30 Nov 2018 05:55:49 +0800 Subject: [PATCH] modify ldsEm to use ldsPca as initialization --- chapter13/LDS/ldsEm.m | 33 ++++++++++++++++++--------------- demo/ch13/lds_demo.m | 29 +++++++++++++++-------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/chapter13/LDS/ldsEm.m b/chapter13/LDS/ldsEm.m index 7611595..d07620a 100644 --- a/chapter13/LDS/ldsEm.m +++ b/chapter13/LDS/ldsEm.m @@ -17,7 +17,7 @@ model = init(X,m); end tol = 1e-4; -maxIter = 1000; +maxIter = 2000; llh = -inf(1,maxIter); for iter = 2:maxIter % E-step @@ -29,20 +29,23 @@ llh = llh(2:iter); function model = init(X, k) -d = size(X,1); -model.mu0 = randn(k,1); -model.P0 = iwishrnd(eye(k),k); -model.A = randn(k,k); -model.G = iwishrnd(eye(k),k); -model.C = randn(d,k); -model.S = iwishrnd(eye(d),d); -% [A,C,Z] = ldsPca(X,k,3*k); -% model.mu0 = Z(:,1); -% model.P0 = ; -% model.A = A; -% model.C = C; -% model.G = ; -% model.S = ; +% d = size(X,1); +% model.mu0 = randn(k,1); +% model.P0 = iwishrnd(eye(k),k); +% model.A = randn(k,k); +% model.G = iwishrnd(eye(k),k); +% model.C = randn(d,k); +% model.S = iwishrnd(eye(d),d); +[A,C,Z] = ldsPca(X,k,3*k); +model.mu0 = Z(:,1); +E = Z(:,1:end-1)-Z(:,2:end); +model.P0 = (dot(E(:),E(:))/(k*size(E,2)))*eye(k); +model.A = A; +E = A*Z(:,1:end-1)-Z(:,2:end); +model.G = E*E'/size(E,2); +model.C = C; +E = C*Z-X(:,1:size(Z,2)); +model.S = E*E'/size(E,2); function model = maximization(X ,nu, U, Ezz, Ezy) n = size(X,2); diff --git a/demo/ch13/lds_demo.m b/demo/ch13/lds_demo.m index fe9e421..42742ae 100644 --- a/demo/ch13/lds_demo.m +++ b/demo/ch13/lds_demo.m @@ -1,19 +1,20 @@ close all; -%% Parameter +% Parameter clear; d = 2; -k = 2; -n = 50; +k = 3; +n = 100; -A = [1,1; - 0 1]; +A = [1,0,1; + 0 1,0; + 0,0,1]; G = eye(k)*1e-3; -C = [1 0; - 0 1]; +C = [1,0,0; + 0 1,0]; S = eye(d)*1e-1; -mu0 = [0; 0]; +mu0 = [0;0;0]; P0 = eye(k); model.A = A; @@ -54,9 +55,9 @@ axis equal hold off %% LDS Subspace -[A,C,z] = ldsPca(x,k,3*k); -y = C*z; -t = size(z,2); +[A,C,nu] = ldsPca(x,k,3*k); +y = C*nu; +t = size(y,2); figure; hold on plot(x(1,1:t), x(2,1:t), 'ro'); @@ -66,9 +67,9 @@ axis equal hold off %% LDS EM -[model, llh] = ldsEm(x,k); -nu = kalmanSmoother(model,x); -y = model.C*nu; +[tmodel, llh] = ldsEm(x,k); +nu = kalmanSmoother(tmodel,x); +y = tmodel.C*nu; figure hold on plot(x(1,:), x(2,:), 'ro');