Optional JAX dependency #500
+1 on a … I think I could live with the docs saying "we recommend using conda-forge; if you need to use pip, install it like this …"
I don't know how to issue a message at all during a pip install, but it would be easy to add a warning/info message when pymbar is invoked. I'd be in favor of this.
I'd like to avoid hacky solutions when proper solutions exist, just not for the particular paradigm I suggested (an optional install that removes dependencies). And I would very much like to avoid anything based on environment variables.
I know that alchemlyb is a little leery of JAX; @xiki-tempula, could you share your thoughts on this? …
We could certainly issue a warning at runtime that pymbar was installed without JAX and might be slower, with information on where to get JAX.
A question about pymbar 4: without JAX, is it impossible to run MBAR at all, or is it just slower? I think JAX is more of a problem for pip, as it is installed every time one sets up a new environment. For conda this is much less of an issue, as conda would only need to install JAX once, and mamba is much faster than pip.
If … That way both ecosystems would still have an identical …
It just runs slower. We use JAX's JIT for all of the acceleration and, if JAX is unavailable, fall back to a passthrough in code the user never sees.
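A minimal sketch of that optional-import pattern, assuming the acceleration is applied through a `jit` decorator. The fallback decorator, the `HAS_JAX` flag, and the `log_sum_exp` function are illustrative only and are not taken from pymbar's source.

```python
import numpy as np

try:
    # JAX is optional: use its JIT compiler and NumPy-compatible API if present.
    import jax.numpy as jnp
    from jax import jit

    HAS_JAX = True
except ImportError:
    # Fallback: plain NumPy plus a passthrough "jit" that returns the
    # function unchanged, so the same code runs, just more slowly.
    jnp = np
    HAS_JAX = False

    def jit(func):
        return func


@jit
def log_sum_exp(values):
    """Numerically stable log(sum(exp(values)))."""
    shift = jnp.max(values)
    return shift + jnp.log(jnp.sum(jnp.exp(values - shift)))


print(HAS_JAX, float(log_sum_exp(np.array([0.0, 0.0]))))
```

Because the decorator and the array module are swapped in one place, the numerical code itself never branches on whether JAX is installed.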
Possibly, yes-ish. The main problem is that we have dependencies specified in the setup.py file, and if we don't very carefully tell conda how to build it, conda-build will yell at us for installing dependencies through pip that were not resolved through conda-build. I think it's doable without too much trouble, though. Based on the comments here, several people also seem to be leaning towards an optional or opt-in style JAX, so it may be worth doing both anyway.
If that is the case, I don't think there is a need for a different version. I was worried that without JAX one could not use MBAR at all, but if it is just slower, printing a warning when it is used without JAX should be fine. When deploying the package, could we ship both pip and conda without JAX and just document it?
Hmmm, I do think that maybe at run time we just issue a warning if JAX isn't installed, saying things could be faster, but then have a … I do like it when we make it easy for users to do the right thing, but perhaps in this case, since no functionality is lost, a warning is enough?
I would say that MBAR without JAX is fully functional, so issuing a warning saying "Jax not installed, could be faster" or something like that is a reasonable compromise. It should be a warning that is issued every time; I don't want people publishing papers about how pymbar is bad because it's not accelerated (it happens). I don't foresee a continual forking of packages: there might be a "takes more to install, but is faster" package and a default, but never more than that.
I would envision the warning being printed any time we hit a code path in the library that could be accelerated.
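One way to read "warn on every accelerable code path" is a small helper called at the top of each such function. This is a hedged sketch: `warn_if_unaccelerated` and `solve_free_energies` are hypothetical names, the `RuntimeWarning` category and message wording are assumptions, and the function body is only a placeholder for real work.

```python
import importlib.util
import warnings

# True when the optional "jax" package is importable in this environment.
HAS_JAX = importlib.util.find_spec("jax") is not None


def warn_if_unaccelerated(name):
    """Hypothetical helper: warn that `name` is running without JAX."""
    if not HAS_JAX:
        warnings.warn(
            # Leading newline so the message starts on its own line.
            f"\n{name} is running without JAX; installing JAX could make it faster.",
            RuntimeWarning,
            stacklevel=3,  # attribute the warning to the caller of the accelerable function
        )


def solve_free_energies(energies):
    """Hypothetical accelerable code path; the body is only a placeholder."""
    warn_if_unaccelerated("solve_free_energies")
    return sum(energies)
```

Python's default warning filters show a given warning once per calling location, so how often users actually see it depends on their filter settings; a caller who wants it on every call can opt in with `warnings.simplefilter("always")`.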
This PR makes JAX an optional dependency at the pip level and, if JAX could not be imported, raises a warning telling users things could be faster. On the `pip` side, there will be two different install options:

* `pip install pymbar` will install without JAX
* `pip install pymbar[jax]` will install with JAX

Partial progress on choderalab#500. Until the [conda feedstock](conda-forge/pymbar-feedstock#34) issue is resolved with a new recipe, I want to leave choderalab#500 open to keep documenting it.

Note: there is an outstanding TODO in the `README.md` file to fill in once we have an official name for the lite version on conda-forge.
I've made #503 to turn JAX into an optional dependency. A warning is issued every time the … I went with the … For conda-forge, I should be able to handle that entirely inside its feedstock, outside of this repo. The main question is the naming scheme. How does this sound?
It does create a small point of conflation where …
The warning looks like this: https://github.com/choderalab/pymbar/actions/runs/5245393269/jobs/9472772645?pr=503#step:6:199. I am going to add a newline at the start so it aligns a bit better.
For conda-forge, can we do …? An example would be Gromacs: https://anaconda.org/conda-forge/gromacs/files
The warning looks fine. I'll defer to others on the best way to install.
Yes, we could, and we would have to specify a build string more exactly, as they do in the Gromacs feedstock. I would argue using the package specifier of: … Since we only have standard MBAR and accelerated MBAR, I think adding the burden of users to know … The way I am proposing is to have multiple outputs, like the feedstock here, so that we (maintainers) don't have to think about build strings or about getting conda-forge to build all the versions through specification, and users don't have to think at all unless they specifically don't want JAX. We can even make it so that … If people have strong opinions on either choice, I'm happy to implement either.
I don't have a strong preference. I'm happy with …
Thanks @Lnaden, the warning also LGTM!
+1 to this approach, as it seems to be a reasonably common pattern on conda when omitting chunky dependencies (e.g. …)
I think the warning looks good as well.
@xiki-tempula I thought more about it, and there's nothing saying we won't wind up with both in the long run. If we get more/different accelerators as options in the future, having a …
I think this is a really good idea.
Though my instinct would probably be to have a no-accelerator conda version and a single accelerated version that uses relatively common, not-weird accelerators. We could potentially have alternate accelerators just on pip.
I've merged the PR. I'm going to keep this issue open until we have the conda-forge release cut. I'm also going to resolve some other quick issues as separate PRs before a new release.
@Lnaden is there anything I can do to help unblock this?
I was hoping to get #509 in before this, but that's blocked because I want other people's reviews on it. I think I'll just cut a 4.0.2 for optional JAX and move forward with the other parts later. Let me work on that; I should be able to get something up today, at least to PyPI.
Actually, there is one fragment from #509 I want to get in to fix a SciPy >=1.9 bug; then I'll release. It should be pretty quick.
Okay, pymbar 4.0.2 is on PyPI. Without JAX: … Working through the conda-forge version now.
@SimonBoothroyd The conda-forge version has been merged and is propagating out. I'll keep an eye on it and keep this issue open until I'm sure it's fully out there. That said, I think this issue is resolved for all intents and purposes. If you're using the PyPI version, it's out there now and ready to go if you want a lighter version of PyMBAR.
And with that, conda-forge has: Without JAX: `conda install pymbar-core`. I think that fully resolves this issue! I'm going to close it, but it can always be reopened if there is a need. Thank you, everyone, for your contributions to this!
Amazing, thanks so much @Lnaden!
@SimonBoothroyd on the conda-forge feedstock has requested a pymbar where JAX is optional, for a leaner install where the full MBAR solver is not needed (for their use case). This can be done with no modifications to the code, only the build, since all the JAX calls are optional and checked at runtime, I think.
This would require a conda and a pip solution at the same time.
The conda-forge/conda-build solution is fairly easy, with the `outputs` directive allowing multiple builds. I've worked on the QCFractal package, which does that, and have the example here; it matches this use case pretty precisely.

The pip solution could be to make JAX a feature, through an optional dependency, instead of installing it by default. The naive solution would have `jax` be an `extras_require` keyword, resulting in two commands like:

* `pip install pymbar`
* `pip install pymbar[jax]`

both being valid, the latter one getting JAX. I personally am against this approach because I think the idea is that we want JAX installed by default, but I would like other opinions as well. I don't know of an `extras_require` (setup.py variant) or a `project.optional-dependencies` (pyproject.toml variant) which blocks/omits primary dependencies. There are other options, like having multiple build projects, but I don't want to explore those if the above is acceptable to everyone else, due to the increased build/maintenance complexity.
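The naive extras approach, and why it cannot subtract a primary dependency, can be modelled in a few lines. This is a toy sketch: `requirements_for`, `BASE_REQUIRES`, and `EXTRAS_REQUIRE` are hypothetical names, the dependency lists are made up for illustration, and real pip resolution (version specifiers, name normalization, transitive dependencies) is far more involved.

```python
import re

# Illustrative metadata only; NOT pymbar's real dependency lists.
BASE_REQUIRES = ["numpy", "scipy"]
EXTRAS_REQUIRE = {"jax": ["jax"]}


def requirements_for(spec):
    """Toy model of which packages pip installs for a spec like "pymbar[jax]"."""
    match = re.fullmatch(r"\s*([A-Za-z0-9_.-]+)\s*(?:\[([^\]]*)\])?\s*", spec)
    if match is None:
        raise ValueError(f"unrecognised install spec: {spec!r}")
    extras = [e.strip() for e in (match.group(2) or "").split(",") if e.strip()]
    # Base dependencies are always installed, whatever extras are requested.
    requirements = list(BASE_REQUIRES)
    for extra in extras:
        # An extra can only add packages; there is no syntax to remove one.
        # Unknown extras add nothing (pip warns about them but carries on).
        requirements.extend(EXTRAS_REQUIRE.get(extra, []))
    return requirements


print(requirements_for("pymbar"))
print(requirements_for("pymbar[jax]"))
```

Because an extra only ever appends to the base list, "JAX by default, with an opt-out" cannot be expressed with extras alone, which is the limitation the issue raises.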