-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(contrib.hsgp): convert matern spectral density from frequency domain #1811
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow! Thanks for checking this! It is a great idea! It is easy to get lost in the pi terms 😄
@brendancooley, do we also need to fix the formula in the LaTeX docstring? |
Good call! Just pushed fix |
@pytest.mark.parametrize( | ||
argnames="x1, x2, length, ell", | ||
argvalues=[ | ||
(jnp.array([[1.0]]), jnp.array([[0.0]]), jnp.array([1.0]), 5.0), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you change those global arrays into numpy arrays np.array(...)
? We don't want to trigger any global jax code because users might need to change floating precision or the default device platform.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. We had a bunch of examples of this pattern in the hsgp test suite. Switched them all over and added some handling where needed for numpy arrays in the source code c9c1457
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests are failing because a type OR operation. I'm not sure about the type hint to use... Maybe we need
from jaxlib.xla_extension import ArrayImpl
?
Or simply take the OR operator outside the isinstance
😄?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
python 3.9...
let me know how this looks to you 56e35f5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is green ✅🙌
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a JAX typing expert but this looks very readable 😄
As I was implementing these tests (to support development in #1805) I noticed that the Matérn spectral density implementation was not giving results consistent with scikit-learn. After staring at Rasmussen and Williams Chapter 4 for a bit I finally noticed that their spectral densities are in the frequency domain, with the inputs scaled by$2 \pi$ . Riutort-Mayol 2023 adopt Rasmussen and Williams' parameterization initially but note that the input vectors are frequencies and convert appropriately for the specific $\nu = \infty$ , $\nu =3/2$ and $\nu =5/2$ examples (see screenshots below).
After dropping the$4 \pi^2$ term from the Matern spectral density function I am able to replicate the exact covariance functions from scikit-learn.
@juanitorduz let me know if I'm missing anything or misunderstanding here -- still trying to wrap my head around some of this stuff and get the terminology right. Additionally, let me know if you have any suggestions on how to make the test suite stronger. The goal is to add vector-valued alpha and length examples to the tests and use those to guide development on #1805.