Optional JAX dependency #500

Closed
Lnaden opened this issue May 25, 2023 · 30 comments

Comments

@Lnaden
Contributor

Lnaden commented May 25, 2023

@SimonBoothroyd on the conda-forge feedstock has requested a pymbar where JAX is optional, giving a leaner install for use cases (like theirs) where the full MBAR solver is not needed. I think this can be done with no modifications to the code, only the build, since all the JAX calls are optional and checked at runtime.

This would require a Conda and Pip solution at the same time.

The Conda-forge/build solution is fairly easy with the outputs directive allowing multiple builds. I've worked on the QCFractal package which does that and have the example here; it matches this use case pretty precisely.

The Pip solution could be solved by making JAX an opt-in feature through an optional dependency instead of installing it by default. The naive solution would have jax listed under the extras_require keyword, resulting in two commands:

  • pip install pymbar
  • pip install pymbar[jax]

Both would be valid; the latter one gets JAX.

I personally am against this approach because I think the idea is we want JAX installed by default, but I would like other opinions as well. I don't know of an extras_require (setup.py variant) or a project.optional-dependencies (pyproject.toml variant) mechanism which blocks/omits primary dependencies. There are other options like having multiple build projects, but because of the increased build/maintenance complexity I don't want to explore those if the above is acceptable to everyone else.
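
To make the limitation concrete, here is a minimal Python sketch of how extras resolve: an extra can only add requirements on top of the base list, it cannot take one away. The dependency lists and the `resolve_requirements` helper are hypothetical and only for illustration; in a real setup.py the two structures would be handed to setuptools as `install_requires` and `extras_require`.

```python
# Hypothetical dependency lists, not copied from pymbar's setup.py.
INSTALL_REQUIRES = ["numpy", "scipy"]
EXTRAS_REQUIRE = {"jax": ["jax", "jaxlib"]}


def resolve_requirements(extras=()):
    """Return the requirements an install with the given extras would ask for.

    Illustrative helper only: it mimics the additive behaviour of extras,
    where `pkg[extra]` means the base requirements plus that extra's list.
    """
    requirements = list(INSTALL_REQUIRES)
    for extra in extras:
        requirements.extend(EXTRAS_REQUIRE[extra])
    return requirements


print(resolve_requirements())         # analogous to `pip install pymbar`
print(resolve_requirements(["jax"]))  # analogous to `pip install pymbar[jax]`
```

Under this scheme the only way to get a JAX-free install is to keep JAX out of the base list, which is why the bare `pip install pymbar` command ends up being the lean one.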

@mikemhenry
Contributor

+1 on a pymbar-core or something that doesn't have jax on conda forge, and yes I don't think there is a way to do it with pip without doing a hack like setting PYMBAR_NO_JAX=TRUE as an envar and then picking that up in setup.py, but I don't think you can hack that in with a pyproject.toml

I think I could live with the docs saying "we recommend using conda-forge; if you need to use pip, install it like this: pip install pymbar[jax]". We could also print a warning when a user installs it like pip install pymbar and say something like "Warning, installed without jax support, [steps on how to reinstall with jax support]".

@Lnaden
Contributor Author

Lnaden commented May 25, 2023

We could also print a warning

I don't know how to issue a message at all with a pip install, but it would be an easy thing to add a warning/info on Pymbar invocation. I'd be in favor of this.

a hack like setting PYMBAR_NO_JAX=TRUE as an envar

I'd like to avoid hacky solutions when there are solutions, just not for the particular paradigm I suggested (optional thing removes dependencies). And I very much would like to avoid envar anything.

@mrshirts
Collaborator

mrshirts commented May 25, 2023

I know that alchemlyb is a little leery of jax; @xiki-tempula could you maybe speak up on this and share your thoughts?

@mrshirts
Collaborator

We certainly could have a warning at run time that it was installed without jax and might be slower, with information on where to get jax.

@xiki-tempula

I wonder for pymbar 4. Is it like without JAX, one cannot run mbar at all or it is just slower?

I think JAX is more of a problem for pip, as it is installed every time one sets up a new env.

For conda, this is much less of an issue, as conda would only need to install JAX once and mamba is much faster than pip.
I guess pymbar-core could be one solution, and one could also have it as pymbar=*=*JAX, where a build variant installs JAX as a conda dependency, and pymbar=*=*core, which would be another build that doesn't install JAX. But I feel for conda the JAX is much less of an issue.

@SimonBoothroyd

If pip is a sticking point (it seems like there are a few several-year-old feature requests to support this kind of thing!) could we maybe just punt on that for now, and just split on conda?

That way both ecosystems would still have an identical pymbar package available, and then if people don't want jax for now they'll need to either install it themselves or get it from conda?

@Lnaden
Contributor Author

Lnaden commented May 26, 2023

I wonder for pymbar 4. Is it like without JAX, one cannot run mbar at all or it is just slower?

It just runs slower. We use JAX's JIT for all the acceleration and just provide a passthrough in code the user never sees if JAX is unavailable.
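
A minimal sketch of what such a passthrough can look like, assuming a plain try/except around the import. The names `HAVE_JAX` and `scaled_sum` are made up for this example and this is not a copy of the pymbar source; the only JAX call used is `jax.jit`.

```python
try:
    # Use the real JIT compiler when JAX is installed.
    from jax import jit
    HAVE_JAX = True
except ImportError:
    HAVE_JAX = False

    def jit(fn, *args, **kwargs):
        """No-op stand-in for jax.jit: hand the function back unchanged."""
        return fn


@jit
def scaled_sum(a, b):
    # Plain arithmetic, so the same body works with or without JAX.
    return a * 2.0 + b


print(HAVE_JAX, float(scaled_sum(1.5, 2.0)))
```

Callers use `scaled_sum` the same way in both cases; only the speed differs, which is the behaviour described above.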

could we maybe just punt on that for now, and just split on conda?

Possibly yes-ish. The main problem is we have dependencies specified in the setup.py file and if we don't very carefully tell conda how to build it, conda-build will yell at us for installing dependencies through pip that were not resolved through conda-build. I think it's doable though without too much issue.

I also think that there seem to be several people leaning towards optional or opt-in style JAX based on the comments here, so it may be worth doing both anyways.

@xiki-tempula

xiki-tempula commented May 26, 2023

It just runs slower. We use JAX's JIT for all the acceleration and just provide a passthrough in code the user never sees if JAX is unavailable.

If this is the case, I don't think there is a need for a different version? I was worried that without JAX, one cannot use mbar. But if it is just a speed issue, printing a warning when it is used without JAX should be fine. When deploying the package, having both pip and conda without JAX and just documenting it should be fine?

@mikemhenry
Contributor

Hmmm, I do think that maybe at run time we do just issue a warning if JAX isn't installed and say things could be faster, but then have a pymbar-jax conda-forge package and a pip install pymbar[jax] setup.

I do like it when we make it easy for users to do the right thing, but perhaps in this case since no functionality is lost, a warning is enough?

@mrshirts
Collaborator

I would say that mbar w/o jax is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time - I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens).

I don't ever see a continual forking of packages. I think that there might be a "takes more to install but is faster" package and a default, but never more than that.

@mikemhenry
Contributor

I would say that mbar w/o jax is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time - I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens).

I would envision the warning being printed anytime we hit a code path in the lib that could be accelerated.

Lnaden added a commit to Lnaden/pymbar that referenced this issue Jun 12, 2023
This PR makes JAX an optional dependency at the pip level and raises a warning if JAX could not be imported to tell users it could be faster.

On the `pip` side, there will be 2 different install options:

* `pip install pymbar` will install without JAX
* `pip install pymbar[jax]` will install with JAX.

Partial progress on choderalab#500. Until the [conda feedstock](conda-forge/pymbar-feedstock#34) issue is resolved with a new recipe, I want to leave choderalab#500 open to keep documenting it.

Note: there is an outstanding TODO on the `README.md` file to fill in once we have an official name for the lite version on conda-forge.
@Lnaden
Contributor Author

Lnaden commented Jun 12, 2023

I've made #503 to turn JAX into an optional dependency.

A warning is issued every time the import block is called from the mbar_solvers code and should be fairly large and visible, but I cannot say if it would be "obvious" enough. I decided to only issue on import rather than every time we could JIT since I think that would be far too annoying to users. Feedback on this idea is welcome.
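
Roughly, the import-time approach can be sketched as below. This is an illustration only, not the code from #503: the function name `warn_if_accelerator_missing` and the message text are made up, and the check uses the standard library's `importlib.util.find_spec` and `warnings.warn`.

```python
import importlib.util
import warnings


def warn_if_accelerator_missing(module_name="jax"):
    """Warn if `module_name` cannot be found; return True when it is missing."""
    if importlib.util.find_spec(module_name) is None:
        warnings.warn(
            f"\n{module_name} was not found: results are the same, but the "
            f"solvers could be faster with it installed "
            f"(for example `pip install pymbar[jax]`).",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


# Runs once, when the module holding this line is first imported,
# instead of on every call that could have been JIT-compiled.
ACCELERATOR_MISSING = warn_if_accelerator_missing()
```

Because Python caches imported modules, the check runs a single time per process, which keeps the message visible without repeating it on every solver call.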

I went with the pymbar[jax] and pymbar (no jax) scheme for pip.

For conda-forge, I should be able to handle that fully inside its feedstock outside of this repo. The main question is the naming scheme. How does this sound?

  • Conda pymbar-core maps to PyPI's pymbar (no JAX)
  • Conda pymbar maps to PyPI's pymbar[jax] (yes JAX)

It does create the small conflation point where pymbar is technically a slightly different dependency set on Conda than on PyPI. But I feel that's okay. Thoughts?

@Lnaden
Contributor Author

Lnaden commented Jun 12, 2023

The warning looks like this: https://github.com/choderalab/pymbar/actions/runs/5245393269/jobs/9472772645?pr=503#step:6:199

I am going to add a new line at the start so it aligns a bit better.

@xiki-tempula

xiki-tempula commented Jun 12, 2023

For conda-forge, can we do
pymbar=*=*core to get the no jax build?
pymbar=*=*jax will be the default and gets you the JAX build?

An example would be Gromacs, https://anaconda.org/conda-forge/gromacs/files
Where one can do gromacs=*=*cuda* to get the cuda build and gromacs=*=*openmpi* to get the openmpi build.

@mrshirts
Collaborator

The warning looks fine. I'll defer to others on the best way to install.

@Lnaden
Contributor Author

Lnaden commented Jun 12, 2023

For conda-forge, can we do...

Yes we could, and we would have to specify a build string more exactly like they do in the gromacs feedstock. I would argue using the package specifier of: {package name}={version}={build selector} like you have for pymbar=*=*core is far less common for most users than just {package name}. I also don't know how to specify an actual default when you have multiple build versions unless the user really knows what they want/have, like they might for GROMACS where they'll know if they have gpu/cpu options. The advantage to this is everything is all under one package, e.g. "conda-forge/gromacs", on the cloud website and search.

Since we only have standard MBAR and accelerated MBAR, I think adding the burden of users to know {package name}={version}={build selector} and then what selector they want might be more than we're looking for, especially if we want users to just default to conda install pymbar and get the accelerated version.

The way I am proposing is to have multiple outputs like the feedstock here where we (maintainers) don't have to think about build strings or getting conda-forge to build all the versions through specification, and then the users don't have to think at all unless they specifically don't want JAX. We can even make it so that pymbar requires pymbar-core so the users who don't know there is a lighter weight install will see that and at least have a breadcrumb. The downside to this is it does create two separate packages in the cloud and search rather than one.

If people have strong opinions on either choice, I'm happy to implement either.

@xiki-tempula

I don't have a strong preference. I'm happy with pymbar-core/pymbar. I kind of like the build variant approach, but if pymbar-core/pymbar is easy for the users and developers, I don't have any objection to that.

@SimonBoothroyd

Thanks @Lnaden, the warning also LGTM!

How does this sound?

  • Conda pymbar-core maps to PyPI's pymbar (no JAX)
  • Conda pymbar maps to PyPI's pymbar[jax] (yes JAX)

+1 to this approach as it seems to be a reasonably common pattern on conda when omitting chunky dependencies (e.g. matplotlib-base)

@xiki-tempula

I think the warning looks good as well.

@Lnaden
Contributor Author

Lnaden commented Jun 12, 2023

@xiki-tempula I thought more about it and there's nothing saying we won't wind up with both in the long run. If we get more/different accelerators as options in the future, having a pymbar-core which has no accelerators and then pymbar=*=*{accelerator}-like syntax might be something that comes up.

@xiki-tempula

I think this is a really good idea.

@mrshirts
Collaborator

Though probably my instinct would be to have a no-accelerator conda version, and a single accelerator version that uses relatively common and not-weird accelerators. We could potentially have alternate accelerators just on pip.

@Lnaden
Contributor Author

Lnaden commented Jun 13, 2023

I've merged the PR. I'm going to keep this issue open until we have the Conda-Forge release cut. I'm also going to resolve some other quick items as separate PRs before a new release.

@SimonBoothroyd

@Lnaden is there anything I can do to help unblock this?

@Lnaden
Contributor Author

Lnaden commented Aug 11, 2023

I was hoping to get #509 in before this, but that's blocked as I want other people's reviews for it. I think I'll just cut a 4.0.2 for optional JAX and move forward with the other parts later. Let me work on that, I should be able to get something up today at least to PyPI

@Lnaden
Contributor Author

Lnaden commented Aug 11, 2023

Actually, there is one fragment I want to get in to fix a SciPy >=1.9 bug from #509, then I'll release, should be pretty quick.

@Lnaden
Contributor Author

Lnaden commented Aug 11, 2023

Okay. Pymbar 4.0.2 is on PyPI.

Without JAX: pip install pymbar
With JAX: pip install pymbar[jax]
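
If you want to confirm from Python which extras an installed distribution declares, the standard library's `importlib.metadata` can read them. A small sketch; the helper name `declared_extras` is made up for this example, and it returns an empty list when the distribution is not installed at all.

```python
from importlib import metadata


def declared_extras(dist_name):
    """Return the extras a distribution declares, or [] if it is not installed."""
    try:
        meta = metadata.metadata(dist_name)
    except metadata.PackageNotFoundError:
        return []
    # `Provides-Extra` appears once per declared extra; get_all returns
    # None when the header is absent, so fall back to an empty list.
    return meta.get_all("Provides-Extra") or []


print(declared_extras("pymbar"))
```

This only reports what the package declares, not whether JAX itself is present in the environment.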

Working through Conda-Forge version now.

@Lnaden
Contributor Author

Lnaden commented Aug 11, 2023

@SimonBoothroyd The Conda-Forge version has been merged and is propagating out. I'll keep an eye on it and keep this issue open until I'm sure it's fully out there.

That said, I think this issue is resolved for all intents and purposes. If you're using the PyPI version, it's out there now and ready to go if you want a lighter version of PyMBAR now.

@Lnaden
Contributor Author

Lnaden commented Aug 11, 2023

And with that Conda-forge has:

Without JAX: conda install pymbar-core
With JAX: conda install pymbar

I think that fully resolves this issue! I'm going to close it but can always reopen if there is a need.

Thank you everyone for your contributions to this!

@Lnaden Lnaden closed this as completed Aug 11, 2023
@SimonBoothroyd

Amazing - thanks so much @Lnaden !


5 participants