Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make using JAX (or any accelerator) a toggle #508

Closed
wants to merge 8 commits into from

Conversation

Lnaden
Copy link
Contributor

@Lnaden Lnaden commented Jun 16, 2023

This PR overhauls how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself.

This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.

  • Importing is all handled through an init_accelerator method with matching name (i.e. init_numpy or init_jax).
  • All items which need to be set in the mbar_solvers namespace are set through the global word of Python in the init_X method and therefore are cast up to the full mbar_solvers namespace.
  • The mbar_solvers module now has state of the whole module and exists as ONE OR THE OTHER at any given time depending on when the last time the accelerator was set. I.e. You cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries.
  • Default is JAX
  • I'm calling numpy an "accelerator" even though its the fall back.

Given the drastically different way the mbar_solver library is loaded relative to most other Python codes, I would like to formally request input from @mikemhenry and @mrshirts about this approach, as well as anyone else who wants to comment.

@invemichele this is the full implementation of the outlined API and features in #496, so any input you have would also be appreciated.

A possible future task is to split out all the import logic into a different submodule/folder just to make the mbar_solvers module simpler to read, especially if we get more accelerators in the future

Lnaden added 2 commits June 16, 2023 14:31
…s of work still, including how to delay JIT

Rebased from master
This PR overhuals how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself.

This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.

* Importing is all handled through an `init_accelerator` method with matching name (i.e. `init_numpy` or `init_jax`).
* All items which need to be set in the `mbar_solvers` namespace are set through the `global` word of Python in the `init_X` method and therefore are cast up to the full `mbar_solvers` namespace.
* The `mbar_solvers` module now has state of the whole module and exists as ONE OR THE OTHER at any given time depending on when the last time the accelerator was set. I.e. You cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries.
* Default is JAX
* I'm calling numpy an "accelerator" even though its the fall back.
@Lnaden Lnaden requested review from mrshirts and mikemhenry June 16, 2023 20:56
@codecov
Copy link

codecov bot commented Jun 16, 2023

Codecov Report

Merging #508 (511d867) into master (cfe49fc) will increase coverage by 0.01%.
The diff coverage is 98.30%.

@Lnaden
Copy link
Contributor Author

Lnaden commented Jun 16, 2023

There is a comparable and probably more Pythonic way to do this that doesn't involve all the globals and the init_X functions, but obscures whats happening in other ways.

All the methods which have the overlapping namespaces could be contained in their own modules and then imported as a single module of methods. E.g. accelerator_operations would be a module that imports pad, s_, jit, optimize, etc. and then we cast each module to a dictionary in the main mbar_solvers library so we can use them like

Hinvg = accelerator_operations["jax"].lstsq(H, g, rcond=-1)[0]

In code we would set something like

accelerator = "jax"
Hinvg = accelerator_operations[accelerator].lstsq(H, g, rcond=-1)[0]

and then whatever MBAR setting is chosen it uses that accelerator. This still has the problem of the module having "state" as it were unless all the functional calls were wrapped to accept an "accelerator" keyword or something. I dont have a notion of how to set that initially for the JIT'd functions though correctly and would have to think about it.

Something to consider.

@jchodera
Copy link
Member

jchodera commented Jun 19, 2023

This is an incredibly naïve question, but is there a way that the accelerators could all subclass an MBARAccelerator base class with a simple API that exposes the major computations that the accelerator needs to perform?

Each accelerator could be implemented in a separate module (file), and the __init__.py that controls how the pymbar.accelerators subpackage is loaded could load all available accelerators and provide a way to retrieve the current "fastest" and available ones.

The complexities with jax are that the JIT requires some time to create the JITed methods, and that this is currently managed by various decorators at the file level. But these functions could be created and JITed in the initmethod by callingjit(...)` directly.

Could this reduce the complexity in allowing each accelerator to simply be a subclass that manages its own specialized code, or fails at initialization with a common exception if the accelerator cannot be used in the current environment?

@Lnaden
Copy link
Contributor Author

Lnaden commented Jun 19, 2023

could all subclass an MBARAccelerator base class with a simple API that exposes the major computations

Yes, we could do that, and that would be a much more Pythonic way to do it that doesn't require global methods which are very non-standard in most python codes I've seen. The main reason I didn't is because the mbar_solvers module is functional in its use, not class based.

I had thought of putting all of the actual functional code into a single abstract class, and then each accelerator can implement their own accelerator methods for things like pad and s_; and still keep the actual solver code identical, which has been really useful for implementing JAX without having to rewrite a fully unique solution for JAX and non-JAX.

However, doing so requires changing the API of mbar_solvers to be instance based, which would require bumping the major API version (i.e. pymbar 4.x -> pymbar 5.x) to keep with semantic versioning. We could in theory set default (jax) based methods in the mbar_solvers module which keeps the API in tact and doesn't require a major version update, which we can then put a DeprecationWarning in if we really want to.

I'd like to not use this init_X method if possible, keep the same codepaths, and not break the API all at the same time, so I'm open to ideas.

The complexities with jax are that the JIT requires some time to create the JITed methods...

Yes, its slightly more complicated by the use-case/issue I've been resolving in PR's leading up to this. Having JAX + decorated (or at module load) JIT calls is the 64-bit bitsize setting. #496 has the use case where using JAX for other code with its default 32-bit bitsize and PyMBAR needed 64-bit mode causes other JAX uses to break due to incompatible JAX-based models. #497 and my own tests confirm we cant really get by with 32-bit in PyMBAR. The PR's I have done before this one do stagger the JIT and the setting of 64-bit mode correctly and safely, allowing the loading of pymbar without breaking other JAX uses.

If we can embed the JIT in a class, we can still stagger it. With JAX caching the JIT'd functions, we should be okay.

@mikemhenry
Copy link
Contributor

I think this a good approach to solve our current problem in a way that keeps the API the same, but I agree that for pymbar 5.x, we should re-work this to not use globals.

@Lnaden
Copy link
Contributor Author

Lnaden commented Jun 20, 2023

a good approach to solve our current problem

I think the current use-case (#496) is mitigated with the earlier PR's (#504 and #505), so this is just an extra feature with a bit of future proofing that I don't know if we even will use any time soon if at all, i.e. other accelerators support.

in a way that keeps the API the same, but I agree that for pymbar 5.x, we should re-work this to not use globals.

I do think I can rework this PR to do what @jchodera suggested of making this much more Pythonic with a class-based abstract class system which preserves code paths to avoid duplication. I would then make the mbar_solvers module try to load a default one like we have now, and then expose the methods to preserve the API so we can keep it as a pymbar 4.{Y+1}. I think I would much rather do that rather than this init_X + global approach I have here if the plan is to scrap it anyways.

How does that sound?

Lnaden and others added 6 commits June 20, 2023 14:40
Relies on creating classes with an exposed API. The problem is that JAX doesn't like acting on class methods so I am having to build around it.
…ecompile. Seems to run much slower in tests right now.
Casts methods to static methods every time to ensure that JAX is not serializing the class itself as constants, dramatically slowing down the code execution. Makes for slightly more complicated method call jumping, but otherwise uses the same code paths.
@Lnaden
Copy link
Contributor Author

Lnaden commented Jun 22, 2023

Due to the dramatically different approach I went with to make this much more-pythonic, I am closing this in favor of a new PR to discuss there. I'll migrate relevant discussions from here to it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants