# Use Kalman Filter convergence to better handle long time series #394

**jessegrabowski** (Member) opened this issue on Nov 23, 2024

For systems with many states, Kalman filtering is slow because we have to compute the so-called Kalman gain matrix, $$K = P_\text{prior}Z^T F^{-1}$$, with $$F = ZP_\text{prior}Z^T + H$$. From there, we can compute the posterior covariance matrix, $$P_\text{posterior} = P_\text{prior} - KFK^T$$. These matrices have the following shapes and meanings:

| name | description | shape |
|---|---|---|
| $Z$ | Map from hidden states to observed states | $n_\text{obs} \times n_\text{hidden}$ |
| $P$ | Hidden state covariance matrix | $n_\text{hidden} \times n_\text{hidden}$ |
| $H$ | Measurement error covariance | $n_\text{obs} \times n_\text{obs}$ |
| $F$ | Estimated residual covariance | $n_\text{obs} \times n_\text{obs}$ |
| $K$ | Kalman gain (optimal update to hidden states) | $n_\text{hidden} \times n_\text{obs}$ |

Typically the number of observed states will be quite small relative to the number of hidden states, so the inversion of $F$ isn't so bad. But when the number of hidden states is large, computing $$P_\text{posterior}$$ requires at least 3 matrix multiplications. Actually it's more, because we use the "Joseph form", $$P_\text{posterior} = (I - KZ) P_\text{prior} (I - KZ)^T + KHK^T$$, which guarantees the result is always positive semi-definite (because it avoids subtraction between two PSD matrices), but instead costs 9 matrix multiplications.
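
To make the cost concrete, here's a minimal NumPy sketch of the measurement update above (matrix names follow the table; `measurement_update` is just an illustrative helper, not anything in the codebase):

```python
import numpy as np

def measurement_update(P_prior, Z, H):
    # F is only n_obs x n_obs, so forming and solving against it is cheap
    F = Z @ P_prior @ Z.T + H
    # K = P Z^T F^{-1}; since P and F are symmetric, this is solve(F, Z P)^T
    K = np.linalg.solve(F, Z @ P_prior).T
    I_KZ = np.eye(P_prior.shape[0]) - K @ Z
    # Joseph form: no subtraction of PSD matrices, so the result stays PSD,
    # at the cost of extra dense n_hidden x n_hidden products
    P_post = I_KZ @ P_prior @ I_KZ.T + K @ H @ K.T
    return K, P_post
```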

Then, to compute next-step forecasts, we have to do several more multiplications:

$$ P_\text{prior}^+ = T P_\text{posterior} T^T + R Q R^T $$

where $$T$$ is the same size as $$P$$ (it encodes the system dynamics), while $$R$$ and $$Q$$ are relatively smaller (they describe how shocks enter the system).
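
The matching prediction step, again only a sketch, adds two more dense $n_\text{hidden} \times n_\text{hidden}$ products on top of the update:

```python
def predict_covariance(P_post, T, R, Q):
    # T P T^T pushes the posterior covariance through the dynamics;
    # R Q R^T adds the (typically low-rank) shock covariance
    return T @ P_post @ T.T + R @ Q @ R.T
```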

My point is that all of this is quite expensive to compute. Interestingly, though, $$P_\text{posterior}$$ does not depend on the data at all, and if there are no time-dependent matrices, the covariance matrix will converge to a fixed point given by:

$$P = TPT^T + RQR^T - TPZ^T(ZPZ^T + H)^{-1}ZPT^T$$

This is a discrete algebraic Riccati equation, and can be solved in pytensor as `pt.linalg.solve_discrete_are(A=T.T, B=Z.T, Q=R @ Q @ R.T, R=H)`. Once we have this, we never need to compute $P_\text{prior}$ or $P_\text{posterior}$ again; we can just use the steady-state $P$. Moreover, once $P$ is fixed, so is $F^{-1}$, so the filter updates become extremely inexpensive.
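
As a sanity check, here's a quick SciPy sketch showing that the DARE solution is indeed a fixed point of the covariance recursion. The argument mapping is the same as in the pytensor call above (SciPy just uses lowercase keywords), and the matrices here are random and purely illustrative:

```python
import numpy as np
from scipy.linalg import solve_discrete_are

rng = np.random.default_rng(0)
n_hidden, n_obs, n_shocks = 15, 2, 3

T = rng.normal(size=(n_hidden, n_hidden)) * 0.2  # scaled down so the system is stable
Z = rng.normal(size=(n_obs, n_hidden))
R = rng.normal(size=(n_hidden, n_shocks))
Q = np.eye(n_shocks)
H = np.eye(n_obs)

P = solve_discrete_are(a=T.T, b=Z.T, q=R @ Q @ R.T, r=H)

# One covariance iteration should leave the steady-state P unchanged
P_next = (
    T @ P @ T.T + R @ Q @ R.T
    - T @ P @ Z.T @ np.linalg.solve(Z @ P @ Z.T + H, Z @ P @ T.T)
)
assert np.allclose(P, P_next)
```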

How to use this is not 100% clear to me. I had previously made a `SteadyStateFilter` class that computed the steady-state $P$ and used it from the first iteration. This means we don't do "online learning" for the first several steps. That seems OK to me, since we don't believe all that initial noise anyway. On the other hand, I've never seen this approach suggested in a textbook, which makes me a bit suspicious that it's the right thing to do. I'm not against offering this option, but one downside is that JAX doesn't offer `solve_discrete_are`, and right now statespace is basically 100% dependent on JAX for sampling.
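
For reference, the idea behind that class is roughly the following (a hypothetical sketch continuing from the snippet above, not the actual `SteadyStateFilter` code): solve for $P$ once, freeze the gain, and the per-step work collapses to a few matrix-vector products:

```python
# P is the steady-state covariance from solve_discrete_are above
K_ss = np.linalg.solve(Z @ P @ Z.T + H, Z @ P).T  # frozen Kalman gain

def steady_state_step(x_prior, y):
    v = y - Z @ x_prior          # one-step forecast error (the only data-dependent part)
    x_post = x_prior + K_ss @ v  # mean update with the fixed gain
    return T @ x_post            # next-step prior mean; no covariance math at all
```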

The "safer" option would be to use an `ifelse` in the update step to check for convergence. At every iteration of `kalman_step`, we can compute $||P_\text{prior}^- - P_\text{prior}^+||$ and stop computing updates once it falls below some tolerance. Here's a graph of the supremum norm for a structural model with level, trend, cycle, and 12 seasonal lags. $P_0$ was initialized to `np.eye(15) * 1e7`. The two plots are the same, but the right plot begins at $t=20$:

*(Figure: supremum norm of successive $P_\text{prior}$ differences over time; left panel shows the full series, right panel begins at $t=20$.)*
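
Here's a rough pytensor sketch of what that check could look like inside the step function. The name `kalman_covariance_step` and the `tol` argument are illustrative; in practice we'd presumably also carry a converged flag through `scan` so the expensive branch is skipped entirely once convergence is reached:

```python
import pytensor.tensor as pt
from pytensor.ifelse import ifelse

def kalman_covariance_step(P_prior, T, Z, R, Q, H, tol):
    F = Z @ P_prior @ Z.T + H
    K = pt.linalg.solve(F, Z @ P_prior).T           # Kalman gain, P Z^T F^{-1}
    I_KZ = pt.eye(P_prior.shape[0]) - K @ Z
    P_post = I_KZ @ P_prior @ I_KZ.T + K @ H @ K.T  # Joseph form update
    P_next = T @ P_post @ T.T + R @ Q @ R.T         # next-step prior

    # sup-norm convergence check: once successive priors stop changing,
    # keep reusing the (approximately) steady-state covariance
    converged = pt.max(pt.abs(P_next - P_prior)) < tol
    return ifelse(converged, P_prior, P_next)
```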

Here's a table of tolerance levels and convergence iterations:

| Tolerance | Convergence iteration |
|---|---|
| 1 | 25 |
| 1e-1 | 49 |
| 1e-2 | 108 |
| 1e-3 | 216 |
| 1e-4 | 337 |
| 1e-5 | 457 |
| 1e-6 | 583 |
| 1e-7 | 703 |
| 1e-8 | 827 |

We could leave the convergence tolerance as a free parameter for the user to play with. But we can see that if we pick 1e-2, for instance, anything after roughly 100 time steps is basically free. This would be quite attractive for estimating large, expensive systems or extremely long time series.
