-
Notifications
You must be signed in to change notification settings - Fork 92
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make using JAX (or any accelerator) a toggle #508
Conversation
…s of work still, including how to delay JIT Rebased from master
This PR overhuals how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself. This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future. * Importing is all handled through an `init_accelerator` method with matching name (i.e. `init_numpy` or `init_jax`). * All items which need to be set in the `mbar_solvers` namespace are set through the `global` word of Python in the `init_X` method and therefore are cast up to the full `mbar_solvers` namespace. * The `mbar_solvers` module now has state of the whole module and exists as ONE OR THE OTHER at any given time depending on when the last time the accelerator was set. I.e. You cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries. * Default is JAX * I'm calling numpy an "accelerator" even though its the fall back.
There is a comparable and probably more Pythonic way to do this that doesn't involve all the All the methods which have the overlapping namespaces could be contained in their own modules and then imported as a single module of methods. E.g. Hinvg = accelerator_operations["jax"].lstsq(H, g, rcond=-1)[0] In code we would set something like accelerator = "jax"
Hinvg = accelerator_operations[accelerator].lstsq(H, g, rcond=-1)[0] and then whatever MBAR setting is chosen it uses that accelerator. This still has the problem of the module having "state" as it were unless all the functional calls were wrapped to accept an "accelerator" keyword or something. I dont have a notion of how to set that initially for the JIT'd functions though correctly and would have to think about it. Something to consider. |
This is an incredibly naïve question, but is there a way that the accelerators could all subclass an Each accelerator could be implemented in a separate module (file), and the The complexities with Could this reduce the complexity in allowing each accelerator to simply be a subclass that manages its own specialized code, or fails at initialization with a common exception if the accelerator cannot be used in the current environment? |
Yes, we could do that, and that would be a much more Pythonic way to do it that doesn't require I had thought of putting all of the actual functional code into a single abstract class, and then each accelerator can implement their own accelerator methods for things like However, doing so requires changing the API of I'd like to not use this
Yes, its slightly more complicated by the use-case/issue I've been resolving in PR's leading up to this. Having JAX + decorated (or at module load) JIT calls is the 64-bit bitsize setting. #496 has the use case where using JAX for other code with its default 32-bit bitsize and PyMBAR needed 64-bit mode causes other JAX uses to break due to incompatible JAX-based models. #497 and my own tests confirm we cant really get by with 32-bit in PyMBAR. The PR's I have done before this one do stagger the JIT and the setting of 64-bit mode correctly and safely, allowing the loading of pymbar without breaking other JAX uses. If we can embed the JIT in a class, we can still stagger it. With JAX caching the JIT'd functions, we should be okay. |
I think this a good approach to solve our current problem in a way that keeps the API the same, but I agree that for pymbar 5.x, we should re-work this to not use globals. |
I think the current use-case (#496) is mitigated with the earlier PR's (#504 and #505), so this is just an extra feature with a bit of future proofing that I don't know if we even will use any time soon if at all, i.e. other accelerators support.
I do think I can rework this PR to do what @jchodera suggested of making this much more Pythonic with a class-based abstract class system which preserves code paths to avoid duplication. I would then make the How does that sound? |
Relies on creating classes with an exposed API. The problem is that JAX doesn't like acting on class methods so I am having to build around it.
…ecompile. Seems to run much slower in tests right now.
Casts methods to static methods every time to ensure that JAX is not serializing the class itself as constants, dramatically slowing down the code execution. Makes for slightly more complicated method call jumping, but otherwise uses the same code paths.
Solved the accelerator problem.
Due to the dramatically different approach I went with to make this much more-pythonic, I am closing this in favor of a new PR to discuss there. I'll migrate relevant discussions from here to it. |
This PR overhauls how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself.
This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.
init_accelerator
method with matching name (i.e.init_numpy
orinit_jax
).mbar_solvers
namespace are set through theglobal
word of Python in theinit_X
method and therefore are cast up to the fullmbar_solvers
namespace.mbar_solvers
module now has state of the whole module and exists as ONE OR THE OTHER at any given time depending on when the last time the accelerator was set. I.e. You cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries.Given the drastically different way the
mbar_solver
library is loaded relative to most other Python codes, I would like to formally request input from @mikemhenry and @mrshirts about this approach, as well as anyone else who wants to comment.@invemichele this is the full implementation of the outlined API and features in #496, so any input you have would also be appreciated.
A possible future task is to split out all the import logic into a different submodule/folder just to make the
mbar_solvers
module simpler to read, especially if we get more accelerators in the future