For systems with many states, Kalman filtering is slow because we have to compute the so-called Kalman gain matrix, $$K = P_\text{prior}Z^T F^{-1}$$, with $$F = ZP_\text{prior}Z^T + H$$. From there, we can compute the posterior covariance matrix, $$P_\text{posterior} = P_\text{prior} - KFK^T$$. These matrices have the following shapes and meanings:
| name | description | shape |
|------|-------------|-------|
| $Z$ | Maps hidden states to observed states | $n_\text{obs} \times n_\text{hidden}$ |
| $P$ | Hidden state covariance matrix | $n_\text{hidden} \times n_\text{hidden}$ |
| $H$ | Measurement error covariance | $n_\text{obs} \times n_\text{obs}$ |
| $F$ | Estimated residual covariance | $n_\text{obs} \times n_\text{obs}$ |
| $K$ | Kalman gain (optimal update to hidden states) | $n_\text{hidden} \times n_\text{obs}$ |
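For concreteness, here is a minimal NumPy sketch of a single measurement update under these shape conventions (variable names are illustrative, not the statespace module's API):

```python
import numpy as np

def measurement_update(x_prior, P_prior, y, Z, H):
    """One Kalman measurement update (naive form, for shape intuition only)."""
    F = Z @ P_prior @ Z.T + H                 # (n_obs, n_obs) residual covariance
    K = P_prior @ Z.T @ np.linalg.inv(F)      # (n_hidden, n_obs) Kalman gain
    x_post = x_prior + K @ (y - Z @ x_prior)  # updated hidden state mean
    P_post = P_prior - K @ F @ K.T            # updated hidden state covariance
    return x_post, P_post
```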
Typically the number of observed states will be quite small relative to the number of hidden states, so the inversion of $F$ isn't so bad. But when the number of hidden states is large, computing $$P_\text{posterior}$$ requires at least 3 matrix multiplications. Actually it's more, because we use the "Joseph form", $$P_\text{posterior} = (I - KZ) P_\text{prior} (I - KZ)^T + KHK^T$$, which guarantees the result is always positive semi-definite (because it avoids subtraction between two PSD matrices), but instead costs 9 matrix multiplications.
Then, to compute next-step forecasts, we have to do several more multiplications:
$$ P_\text{prior}^+ = T P_\text{posterior} T^T + R Q R^T $$
where $$T$$ is the same size as $$P$$ (it encodes the system dynamics), while $$R$$ and $$Q$$ are relatively small (they describe how shocks enter the system).
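Again purely as a sketch (not the statespace implementation), the Joseph-form update plus the prediction step look like this in NumPy, and you can count the matrix multiplications directly:

```python
import numpy as np

def joseph_update(P_prior, K, Z, H):
    """Joseph-form covariance update: PSD by construction, but costs
    noticeably more matrix multiplications than P_prior - K F K^T."""
    A = np.eye(P_prior.shape[0]) - K @ Z
    return A @ P_prior @ A.T + K @ H @ K.T

def predict_covariance(P_post, T, R, Q):
    """Next-step prior covariance: P_prior = T P_post T^T + R Q R^T."""
    return T @ P_post @ T.T + R @ Q @ R.T
```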
My point is that this is all quite expensive to compute. Interestingly, though, $$P_\text{posterior}$$ does not depend on the data at all, and if none of the system matrices are time-dependent, the covariance matrix converges to a fixed point given by:
$$P = TPT^T + RQR^T - TPZ^T(ZPZ^T +H)^{-1}ZPT^T$$
This is a discrete algebraic Riccati equation, and can be solved in pytensor as `pt.linalg.solve_discrete_are(A=T.T, B=Z.T, Q=R @ Q @ R.T, R=H)`. Once we have this, we never need to compute $P_\text{prior}$ or $P_\text{posterior}$ again; we can just use the steady-state $P$. And once $P$ is fixed, so is $F^{-1}$, so the filter updates become extremely inexpensive.
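As a sanity-check sketch (using scipy rather than pytensor, with illustrative names), solving the DARE and pre-computing the fixed gain looks like:

```python
import numpy as np
from scipy import linalg

def steady_state_filter_matrices(T, Z, R, Q, H):
    """Solve the DARE for the steady-state prior covariance, then
    pre-compute the fixed Kalman gain so per-step updates are cheap."""
    P = linalg.solve_discrete_are(T.T, Z.T, R @ Q @ R.T, H)
    F = Z @ P @ Z.T + H
    K = P @ Z.T @ np.linalg.inv(F)   # fixed gain; F^{-1} can also be cached
    return P, F, K
```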
How to use this is not 100% clear to me. I had previously made a `SteadyStateFilter` class that computed and used the steady-state $P$ from the first iteration. This means we don't do "online learning" of the covariance for the first several steps. That seems OK to me, since we don't believe all that initial noise anyway. But on the other hand, I've never seen this approach suggested in a textbook, which makes me a bit suspicious that it's the right thing to do. I'm not against offering this option, but one downside is that JAX doesn't offer `solve_discrete_are`, and right now statespace is basically 100% dependent on JAX for sampling.
The "safer" option would be to use an ifelse in the update step to check for convergence. At every iteration of kalman_step, we can compute $||P_\text{prior}^- - P_\text{prior}^+||$ and stop computing updates once its below some tolerance. Here's a graph of the supremium norm for a structural model with level, trend, cycle, and 12 seasonal lags. $P_0$ was initialized to np.eye(15) * 1e7. The two plots are the same, but the right plot begins at t=20:
Here's a table of tolerance levels and convergence iterations:

| Tolerance | Convergence Iteration |
|-----------|-----------------------|
| 1 | 25 |
| 1e-1 | 49 |
| 1e-2 | 108 |
| 1e-3 | 216 |
| 1e-4 | 337 |
| 1e-5 | 457 |
| 1e-6 | 583 |
| 1e-7 | 703 |
| 1e-8 | 827 |
We could leave the convergence tolerance as a free parameter for the user to play with. But we can see that if we pick `1e-2`, for instance, anything after ~100 time steps is basically free. This would be quite attractive for estimating large, expensive systems or extremely long time series.
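In the actual graph this would be an `ifelse` inside `kalman_step`; here is a plain NumPy sketch of the logic (illustrative names, not the statespace API):

```python
import numpy as np

def filter_with_frozen_covariance(ys, x0, P0, T, Z, R, Q, H, tol=1e-2):
    """Kalman filter loop that stops updating the covariance matrices once
    ||P_prior_new - P_prior_old||_inf falls below `tol`. After that point,
    each step only costs the (cheap) mean updates."""
    x, P_prior = x0, P0
    converged = False
    K = None
    for y in ys:
        if not converged:
            F = Z @ P_prior @ Z.T + H
            K = P_prior @ Z.T @ np.linalg.inv(F)
            A = np.eye(P_prior.shape[0]) - K @ Z
            P_post = A @ P_prior @ A.T + K @ H @ K.T       # Joseph form
            P_prior_new = T @ P_post @ T.T + R @ Q @ R.T   # prediction step
            converged = np.max(np.abs(P_prior_new - P_prior)) < tol
            P_prior = P_prior_new
        # mean updates always run; covariance work is skipped after convergence
        x = T @ (x + K @ (y - Z @ x))
    return x, P_prior
```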