Silently changing JAX default config #496
Comments
Thanks for opening this bug report @invemichele ❤️ (At first I thought this was about CPU architecture, but it is about 64-bit floats: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) I'm testing in our CI (#497) to see what happens if I remove 64-bit floats. If we do need them, we should document this behavior and also print a warning. This will be important because I am not sure what happens if someone has already started JAX and we then try to change the config, since the docs say:
|
Looks like we do need 64-bit floats (which is what I thought). @invemichele I can add a warning and improve the documentation; is that sufficient? Unfortunately, it doesn't look like we can dynamically change JAX global configs after startup. |
yes, that would be useful info for people that want to use JAX and pymbar in the same script |
I've been tinkering with this some more, and I don't think there is a reasonable way to expect PyMBAR to operate in 32-bit mode. Implementing a 32-bit mode setting is a bit tricky, but I have an implementation. The problem is that there is no reliable way to expect useful outputs: you can get them, but they are not reliably accurate or converging. @invemichele (and others): is a loud warning sufficient for your use case here, or would it be extremely useful to be able to force 32-bit mode for PyMBAR, on the understanding that converged results are not guaranteed? (If so, I would also issue a very loud warning about using 32-bit floats from PyMBAR.) |
To be clear, I am not interested in pymbar using float32. The issue is that, by changing the global JAX default, pymbar is in practice incompatible with other JAX code. My notebook with a JAX neural network stopped working as soon as I imported pymbar and, since the error came from one of my own JAX lines, it was not clear to me that pymbar was the cause until I went through its source code. A warning about the changed JAX global setting would be very useful for debugging. |
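The debugging difficulty described above (a global flag changed as a side effect of an import) can be caught with a small check around the import. The following is only a sketch: `import_and_report` and its parameters are hypothetical names, not part of pymbar or JAX, and the demo uses the standard-library `json` module with scripted flag readings so that it runs without JAX installed.

```python
import importlib
import warnings


def import_and_report(module_name, read_flag, flag_name="flag"):
    """Import a module and warn if a global flag changed as a side effect.

    `read_flag` is any zero-argument callable returning the current value
    of the flag being watched (hypothetical helper, for illustration only).
    """
    before = read_flag()
    module = importlib.import_module(module_name)
    after = read_flag()
    if before != after:
        warnings.warn(
            f"Importing {module_name!r} changed {flag_name} "
            f"from {before!r} to {after!r}.",
            UserWarning,
            stacklevel=2,
        )
    return module, before, after


# Stand-in demo: the readings are scripted to mimic a flag that flips
# from False to True during the import.
readings = iter([False, True])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    module, before, after = import_and_report(
        "json", lambda: next(readings), flag_name="jax_enable_x64"
    )
report = [str(w.message) for w in caught]
```

With JAX installed, the same helper could presumably be called with `read_flag=lambda: jax.config.jax_enable_x64` and `module_name="pymbar"`. Note that a module that was already imported is returned from the cache without re-running its top-level code, so the check only helps on the first import in a process.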
That would be a viable option. I think we can expand on that to get the safest and most user-controllable approach. From what I can tell, so long as the JIT'd functions haven't been called yet, we can still set the 64-bit mode. So how does this sound:
|
@invemichele I've got the warning for this in #504. Functionally, the JAX config is not set until right before the first JIT call and will issue this pair of warnings: On import (if 32-bit JAX):
On change to 64-bit mode:
I realize I still haven't set the API call, but the warnings are what I wanted to do for this PR first so I don't break the API in testing on top of changing the import logic. |
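The two-stage behavior described in this comment (warn on import if JAX is in 32-bit mode, then warn again when the flag is actually switched) might look roughly like the sketch below. This is not the code from #504: the function names and the warning texts are placeholders, and `FakeConfig` is a stand-in exposing only the two members of `jax.config` that the sketch relies on (`jax_enable_x64` and `update`), so that the example runs without JAX.

```python
import warnings


class FakeConfig:
    """Stand-in for jax.config with just the members used below."""

    def __init__(self, x64=False):
        self.jax_enable_x64 = x64

    def update(self, name, value):
        setattr(self, name, value)


def warn_on_import(config):
    """Warn, without changing anything, if 64-bit mode is currently off."""
    if not config.jax_enable_x64:
        warnings.warn(
            "JAX is in 32-bit mode; it will be switched to 64-bit mode "
            "before the first JIT-compiled call.",  # placeholder text
            UserWarning,
            stacklevel=2,
        )


def enable_x64_before_first_jit(config):
    """Turn the flag on; warn and return True only if the value changed."""
    if config.jax_enable_x64:
        return False
    warnings.warn(
        "Switching JAX to 64-bit mode.",  # placeholder text
        UserWarning,
        stacklevel=2,
    )
    config.update("jax_enable_x64", True)
    return True


config = FakeConfig(x64=False)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_import(config)
    changed_first = enable_x64_before_first_jit(config)
    changed_second = enable_x64_before_first_jit(config)
messages = [str(w.message) for w in caught]
```

With JAX installed, `jax.config` would be passed in place of `FakeConfig`. Returning whether the flag changed keeps the second warning from repeating on every call.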
Magic word closed this, my mistake. Not ready to close until the API to toggle JAX is in. |
In trying to develop the API side of this, I realized this warning doesn't do any real good: because of how JIT decorators work, they all activate on import, before any of the actual functions are called. I can disable the JIT of functions with a global parameter, but I don't know how to check each function once it's called so as to set the x64 flag and then JIT. I need to delay the actions of the decorator until execution. Re-thinking the code now. Even though the current merged version doesn't stop the 64-bit setting on import, it will at least warn you very loudly for now. |
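One common way to delay a decorator's work until execution is to wrap the function at import time but apply the real transformation only on the first call. The sketch below illustrates that pattern in general terms; it is not the fix from #505. `lazy_jit`, `compile_fn`, and `before_first_compile` are hypothetical names, and a recording stand-in replaces `jax.jit` so the example runs without JAX.

```python
import functools


def lazy_jit(compile_fn, before_first_compile=None):
    """Decorator factory that defers `compile_fn(fn)` until the first call.

    `before_first_compile` is an optional zero-argument hook that runs once,
    just before compilation (for example, to set a global config flag).
    """

    def decorator(fn):
        compiled = None  # filled in lazily on the first call

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal compiled
            if compiled is None:
                if before_first_compile is not None:
                    before_first_compile()
                compiled = compile_fn(fn)
            return compiled(*args, **kwargs)

        return wrapper

    return decorator


# Stand-in demo: record when setup and "compilation" actually happen.
events = []


def fake_compile(fn):
    events.append("compile")
    return fn


@lazy_jit(fake_compile, before_first_compile=lambda: events.append("setup"))
def double(x):
    return 2 * x


events_after_decoration = list(events)  # nothing has run yet
first = double(3)
second = double(4)
```

With JAX, this would presumably be used as `lazy_jit(jax.jit, before_first_compile=lambda: jax.config.update("jax_enable_x64", True))`. The sketch is not thread-safe, and each decorated function runs the hook once, so the hook should be safe to repeat.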
Got a fix in #505. Once it's in, I can carry this over to the API. |
I ran into an error while trying to load a JAX NN model, and it took me a while to realize the problem was caused by `import pymbar`. Here the JAX global default is changed to x64, which was incompatible with my stored model. I solved it by setting `force_no_jax = True` here, but it would be nice to have a warning somewhere about this global config change, or a mention of it in the docs.