Make using JAX (or any accelerator) an Instanced Python class (and toggle) #509
…s of work still, including how to delay JIT Rebased from master
This PR overhauls how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as to anyone using the functional solver library directly. This is a rethinking of how to handle different libraries with identically functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.
* Importing is all handled through an `init_accelerator` method with matching name (i.e. `init_numpy` or `init_jax`).
* All items which need to be set in the `mbar_solvers` namespace are set through the `global` keyword of Python in the `init_X` method and are therefore cast up to the full `mbar_solvers` namespace.
* The `mbar_solvers` module now holds state for the whole module and exists as ONE OR THE OTHER at any given time, depending on when the accelerator was last set. I.e. you cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries.
* Default is JAX.
* I'm calling numpy an "accelerator" even though it's the fallback.
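The module-level switching described in this commit message can be sketched with the standard library alone. This is a minimal illustration, not pymbar's code: `init_plain`, `init_fsum`, `_INITIALIZERS`, `log_partition`, and the two `logsumexp` stand-ins are hypothetical names invented for the example. Only the `init_X` naming pattern and the use of `global` come from the description.

```python
import math

# Module-level state: only one backend is active at any given time.
accelerator_name = None
logsumexp = None


def _logsumexp_plain(values):
    """Stand-in for the fallback implementation (hypothetical)."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def _logsumexp_fsum(values):
    """Stand-in for a second backend with the same signature (hypothetical)."""
    peak = max(values)
    return peak + math.log(math.fsum(math.exp(v - peak) for v in values))


def init_plain():
    # ``global`` rebinds the module-level names, so every caller sees the change.
    global accelerator_name, logsumexp
    accelerator_name, logsumexp = "plain", _logsumexp_plain


def init_fsum():
    global accelerator_name, logsumexp
    accelerator_name, logsumexp = "fsum", _logsumexp_fsum


_INITIALIZERS = {"plain": init_plain, "fsum": init_fsum}


def init_accelerator(name):
    """Activate the backend called ``name``, replacing whichever was active."""
    if name not in _INITIALIZERS:
        raise ValueError(f"unknown accelerator: {name!r}")
    _INITIALIZERS[name]()


def log_partition(values):
    """Use whichever ``logsumexp`` is installed at call time."""
    return logsumexp(values)


init_accelerator("plain")
first = log_partition([0.0, 0.0])
init_accelerator("fsum")  # switches the backend for every caller at once
second = log_partition([0.0, 0.0])
```

Because the state lives in the module rather than in an object, the second `init_accelerator` call changes the behavior of every later call, which is the "one or the other" limitation the commit message points out.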
Relies on creating classes with an exposed API. The problem is that JAX doesn't like acting on class methods so I am having to build around it.
…ecompile. Seems to run much slower in tests right now.
Casts methods to static methods every time to ensure that JAX is not serializing the class itself as constants, dramatically slowing down the code execution. Makes for slightly more complicated method call jumping, but otherwise uses the same code paths.
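One way to picture the "cast to static methods" approach is the sketch below. It is only an assumption about the general shape, not the PR's implementation: `ToySolver`, `_double_impl`, and the `jit` argument are hypothetical, and the default `jit` is an identity function standing in for something like `jax.jit`.

```python
import functools


def _double_impl(x):
    # Pure function: nothing here refers to a class or to ``self``.
    return 2 * x


class ToySolver:
    """Hypothetical solver; ``jit`` stands in for a compiler such as ``jax.jit``."""

    def __init__(self, jit=lambda fn: fn):
        compiled = jit(_double_impl)

        # Copy the name and docstring of the documented method onto the
        # replacement so the public API looks unchanged.
        @functools.wraps(ToySolver.double)
        def double(x):
            return compiled(x)

        self.double = double  # instance attribute shadows the class method

    def double(self, x):
        """Return twice ``x``."""
        raise NotImplementedError("replaced in __init__")


solver = ToySolver()
value = solver.double(3)
```

The compiled function never receives `self`, so the compiler has nothing class-related to capture, at the cost of one extra call hop per method.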
Solved the accelerator problem.
…o go into source code to find this flag) Fix lint complaining about jax import on no-jax systems by wrapping JAX and raising appropriately
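"Wrapping JAX and raising appropriately" could look roughly like the following. This is a hedged sketch: `require_module` and `load_jax` are hypothetical helpers, and the hint text is illustrative rather than the message used in the PR.

```python
import importlib


def require_module(module_name, install_hint):
    """Import ``module_name`` or raise ImportError carrying ``install_hint``."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"{module_name} is required for this accelerator but is not installed.\n"
            f"{install_hint}"
        ) from err


def load_jax():
    """Import JAX lazily, so merely importing this file works on no-JAX systems."""
    return require_module("jax", "Install JAX to enable the JAX accelerator.")
```

Deferring the import into a function keeps linters and no-JAX installations from failing at module import time; the error only appears when the JAX accelerator is actually requested.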
Fix tests for jax-toggle
…ls are weird, so added a "real_jit" property that can be set on implementation.
Note: After some more testing, I think I was instancing the Solver classes every time a new MBAR was created and that was what was slowing down my testing, that and I had a native Python I think I can remove all of the static method generator functions and go back to just clean class methods, and do the PyTree registration for good measure just in case. It will clean up the code and maintain speed. Something for me to check and benchmark when I can.
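The re-instancing problem in the note above can be illustrated without JAX. In this sketch every name (`ToySolver`, `get_solver`, `ToyMBAR`, `compile_count`) is hypothetical, and a counter stands in for the cost of JIT-compiling a solver's methods. It shows one possible fix, sharing a single solver per accelerator, and is not necessarily how the PR resolved it.

```python
from functools import lru_cache


class ToySolver:
    compile_count = 0  # simulated JIT compilations across all instances

    def __init__(self, accelerator):
        self.accelerator = accelerator
        # Stand-in for the expensive step: compiling the solver's methods.
        ToySolver.compile_count += 1


@lru_cache(maxsize=None)
def get_solver(accelerator):
    """Return one shared solver per accelerator name instead of a new one each time."""
    return ToySolver(accelerator)


class ToyMBAR:
    def __init__(self, accelerator="jax"):
        # Reuse the cached solver so a new ToyMBAR does not trigger recompilation.
        self.solver = get_solver(accelerator)


first = ToyMBAR()
second = ToyMBAR()
other = ToyMBAR(accelerator="numpy")
```

Here two `ToyMBAR` objects with the same accelerator share one solver, so the simulated compilation happens once per accelerator rather than once per object.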
…nerated methods is negligible and my earlier testing was the fact I was re-instancing the solver, and thus re-JIT'ing everything each time a new MBAR was called (fixed in earlier commit). After testing, here are the results: Testing the timing of test_protocols test using static-generated methods as a relative baseline: The test is 99% as fast on average with PyTree registration. The test is 95% as fast on average without the PyTree registration. So I've opted to use the JAX PyTree registration method and simplify the code substantially by moving all methods back into self methods. Also updated the readme to reflect the new option.
After some testing, I found the speed gain from having pure static generated methods is negligible; my earlier slow tests were caused by re-instancing the solver, and thus re-JIT'ing everything, each time a new MBAR was called (fixed in an earlier commit). After testing, here are the results:
Testing the timing of the test_protocols test using static-generated methods as a relative baseline:
* The test is 99% as fast on average with PyTree registration.
* The test is 95% as fast on average without the PyTree registration.
So I've opted to use the JAX PyTree registration method and simplify the code substantially by moving all methods back into self methods.
So this version is more pythonic, easier to read, almost as fast as pure static methods, and overall implements the API listed in #496. This is ready for review.
The only outstanding question I have is a naming convention: do we want to keep the name "accelerator" as I have in most places, or "solver" which I have in a few others? They are just accelerated by the different libraries.
The only API concern here is the keyword
They say the pre-1.9 behavior was a bug, hence no deprecation warning.
@@ -96,6 +96,7 @@ def __init__(
 n_bootstraps=0,
 bootstrap_solver_protocol=None,
 rseed=None,
+accelerator=None,
I think this name is good
"Either install with pip or conda:\n"
" pip install pybar[jax] \n"
" OR \n"
" conda install pymbar \n"
Should add a `-c conda-forge` in here. I didn't want to use the web UI suggestion feature since I wouldn't be able to maintain the formatting of this warning.
Also you can either fix this in your PR, or merge this one in #510 to fix the RTD builds
Supersedes #508
This PR overhauls how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as to anyone using the functional solver library directly.
This is a rethinking of how to handle different libraries with identically functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.
* `mbar_solvers` file has now been moved to its own module.
* `MBARSolver` general class as was suggested in #508 by @jchodera
* `mbar_solvers` import, so any attempt to `import mbar_solvers` or `from mbar_solvers import X,Y,Z,*` behaves identically to how the main branch currently does. Thus the API is preserved and can keep the 4.y.z version scheme.

The complicated part of this is that the actual functions (e.g. gradient, log W nk, precondition u_kn, etc.) have to be generated as static methods, not tied to the actual `MBARSolver` class. The way JIT for JAX works is it will serialize anything in the function, so any `object` (i.e. `class` and `self` in function) has every one of its methods/parameters serialized as well at compile time, resulting in a massive slowdown if I try to just leave each of the methods as class methods. I get around this with a number of `generate_static...` methods and replace the actual definitions in the `__init__` while still preserving the doc strings and API. In case you're wondering about the details, see the `MBARSolver.__init__` doc string.

Edit: After testing, correct function definitions, and PyTree assignments, this is not a problem and the implementation is 99% as fast as static methods on average.
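For readers unfamiliar with the PyTree approach mentioned in the edit, here is a minimal sketch of jitting an instance method by registering the class as a pytree. It assumes JAX is installed and uses `jax.tree_util.register_pytree_node_class`; `ToySolver` and `scaled_logsumexp` are hypothetical names, and this is not confirmed to match how the PR registers `MBARSolver`.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToySolver:
    """Hypothetical solver class; not pymbar's MBARSolver."""

    def __init__(self, scale):
        self.scale = scale  # plain Python configuration, no arrays

    def tree_flatten(self):
        # No array leaves: the whole object is static auxiliary data,
        # which must be hashable so JAX can key its compilation cache on it.
        return (), (self.scale,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)

    @jax.jit
    def scaled_logsumexp(self, u):
        # ``self`` passes through jit as a pytree; ``u`` is traced as an array.
        return self.scale * logsumexp(u)


solver = ToySolver(2.0)
result = solver.scaled_logsumexp(jnp.zeros(4))
```

With the class registered, `jax.jit` knows how to take `self` apart and rebuild it, so methods can stay ordinary instance methods instead of generated static functions.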
Given the drastically different way the `mbar_solvers` is loaded, and the massive implementation change, I would like to again formally request input from @mikemhenry and @mrshirts about this approach. @jchodera as you suggested the API idea in #508, I'd like to ask for your feedback as well given the implementation was not as clean due to the JIT shenanigans.

@invemichele this is the full implementation of the outlined API and features in #496, so any input you have would also be appreciated.