Skip to content

Commit

Permalink
Make JAX an optional dependency
Browse files Browse the repository at this point in the history
This PR makes JAX an optional dependency at the pip level and raises a warning if JAX could not be imported to tell users it could be faster.

On the `pip` side, there will be 2 different instal options:

* `pip install pymbar` will install without JAX
* `pip install pymbar[jax]` will install with JAX.

Partial progress on choderalab#500. Until the [conda feedstock](conda-forge/pymbar-feedstock#34) issue is resolved with a new recipe, I want to leave choderalab#500 open to keep documenting it.

Note: there is an outstanding TODO on the `README.md` file to fill in once we have an official name for the lite version on conda-forge.
  • Loading branch information
Lnaden committed Jun 12, 2023
1 parent 2740798 commit 783cfc6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,21 @@ The easiest way to install the `pymbar` release is via [conda](http://conda.pyda
conda install -c conda-forge pymbar
```

You can also install `pymbar` from the [Python package index](https://pypi.python.org/pypi/pymbar) using `pip`:
TODO: Add notes about a pymbar-core for non-jax acceleration.

You can also install JAX accelerated `pymbar` from the [Python package index](https://pypi.python.org/pypi/pymbar)
using `pip`:

```bash
pip install pymbar[jax]
```
or the non-jax-accelerated version (which is smaller in dependencies) with
```bash
pip install pymbar
```
Whether you install the JAX accelerated or non-JAX-accelerated version does not
change any calls or how the code is run.


The development version can be installed directly from github via `pip`:

Expand Down
40 changes: 27 additions & 13 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,34 @@
try:
#### JAX related imports
if force_no_jax:
# Capture user-disabled JAX instead "JAX not found"
raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py")
from jax.config import config

config.update("jax_enable_x64", True)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
from jax.numpy.linalg import lstsq
import jax.scipy.optimize as optimize_maybe_jax
from jax.scipy.special import logsumexp

from jax import jit as jit_or_passthrough

use_jit = True
try:
from jax.config import config

config.update("jax_enable_x64", True)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
from jax.numpy.linalg import lstsq
import jax.scipy.optimize as optimize_maybe_jax
from jax.scipy.special import logsumexp

from jax import jit as jit_or_passthrough

use_jit = True
except ImportError:
# Catch no JAX and throw a warning
warnings.warn("********* JAX NOT FOUND *********\n"
" PyMBAR can run faster with JAX \n"
" But will work fine without it \n"
"Either install with pip or conda:\n"
" pip install pybar[jax] \n"
" OR \n"
" conda install pymbar \n"
"*********************************"
)
raise # Continue with the raised Import Error

except ImportError:
# No JAX found, overlap imports
Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
install_requires=["numpy>=1.12",
"scipy",
"numexpr",
"jaxlib;platform_system!='Windows'",
"jax;platform_system!='Windows'"
],
extras_require={
"jax": ["jaxlib;platform_system!='Windows'",
"jax;platform_system!='Windows'"
],
},
)

0 comments on commit 783cfc6

Please sign in to comment.