Merge pull request #504 from Lnaden/jax-64-warning
Warn about JAX bitsize changes
Lnaden authored Jun 15, 2023
2 parents 86199a1 + dc29e7d commit a5fa114
Showing 2 changed files with 61 additions and 14 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -84,6 +84,13 @@ where `results` is a dictionary with keys `mu`, `sigma`, and `Theta`, where `mu[

See the docstring help for these individual methods for more information on exact usage; in Python or IPython, you can view the docstrings with `help()`.
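For example, a quick way to inspect the main estimator class (a minimal sketch; any public PyMBAR object can be inspected the same way):

```python
import pymbar

# Print the docstring for the MBAR estimator class
help(pymbar.MBAR)
```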

JAX needs 64-bit mode
---------------------
PyMBAR needs 64-bit floating-point precision to provide reliable answers. By default, JAX uses
[32-bit (single) precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
PyMBAR will turn on JAX's 64-bit mode when it runs, which may cause issues with other uses of JAX in the same program,
such as existing neural network (NN) models for machine learning.
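If you mix PyMBAR with other JAX code in the same program, one option is to opt in to 64-bit mode yourself before creating any other JAX arrays or models, so every component sees the same precision PyMBAR will enable. A minimal sketch (assuming JAX is installed):

```python
from jax.config import config

# Turn on 64-bit (double) precision before any other JAX objects are created.
# This is the same switch PyMBAR flips when it is called.
config.update("jax_enable_x64", True)

import jax.numpy as jnp

print(jnp.zeros(3).dtype)  # float64 once 64-bit mode is on
```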

Authors
-------
* Kyle A. Beauchamp <[email protected]>
68 changes: 54 additions & 14 deletions pymbar/mbar_solvers.py
@@ -7,6 +7,8 @@
import scipy.optimize
from pymbar.utils import ensure_type, check_w_normalized, ParameterError

logger = logging.getLogger(__name__)

use_jit = False
force_no_jax = False # Temporary until we can make a proper setting to enable/disable by choice
try:
@@ -17,7 +19,21 @@
try:
from jax.config import config

config.update("jax_enable_x64", True)
if not config.x64_enabled:
# Warn that we're going to be setting 64-bit JAX
logger.warning(
"\n"
"****** PyMBAR will use 64-bit JAX! *******\n"
"* JAX is currently set to 32-bit bitsize *\n"
"* which is its default. *\n"
"* *\n"
"* PyMBAR requires 64-bit mode and WILL *\n"
"* enable JAX's 64-bit mode when called. *\n"
"* *\n"
"* This MAY cause problems with other *\n"
"* Uses of JAX in the same code. *\n"
"******************************************\n"
)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
@@ -30,7 +46,7 @@
use_jit = True
except ImportError:
# Catch no JAX and throw a warning
warnings.warn(
logger.warning(
"\n"
"********* JAX NOT FOUND *********\n"
" PyMBAR can run faster with JAX \n"
@@ -61,14 +77,13 @@ def jit_or_passthrough(fn):
# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax.
# Can be mostly ignored

logger = logging.getLogger(__name__)

if use_jit is False:
logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy")
else:
logger.info("JAX detected. Using JAX acceleration.")

# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving the MBAR equations.
# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving
# the MBAR equations.
# Note: we use tuples instead of lists to avoid accidental mutability.
JAX_SOLVER_PROTOCOL = (
dict(method="BFGS", continuation=True),
@@ -110,6 +125,31 @@ def jit_or_passthrough(fn):
scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included


def jit_or_pass_after_bitsize(jitable_fn):
"""
Attempt to set JAX precision if present. This does nothing if JAX is not present
Parameters
----------
jitable_fn: function
A function which can be jit'd
"""

# This will only trigger if JAX is in use
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit mode is being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
return jit_or_passthrough(jitable_fn)

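# Illustration only (not part of the committed code): the decorator above wraps
# each of the JAX kernels defined below, e.g.
#
#     @jit_or_pass_after_bitsize
#     def jax_mbar_gradient(u_kn, N_k, f_k):
#         ...
#
# With JAX available, applying it switches JAX to 64-bit mode (warning once if
# JAX was still at its 32-bit default) and returns a jit-compiled function;
# without JAX it returns the plain NumPy implementation unchanged.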

def validate_inputs(u_kn, N_k, f_k):
"""Check types and return inputs for MBAR calculations.
@@ -167,7 +207,7 @@ def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None):
return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def _jit_self_consistent_update(u_kn, N_k, f_k):
"""JAX version of self_consistent update. For parameters, see self_consistent_update.
N_k must be float (should be cast at a higher level)
@@ -220,7 +260,7 @@ def mbar_gradient(u_kn, N_k, f_k):
return jax_mbar_gradient(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_gradient(u_kn, N_k, f_k):
"""JAX version of MBAR gradient function. See documentation of mbar_gradient.
N_k must be float (should be cast at a higher level)
@@ -263,7 +303,7 @@ def mbar_objective(u_kn, N_k, f_k):
return jax_mbar_objective(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_objective(u_kn, N_k, f_k):
"""JAX version of mbar_objective.
For parameters, see mbar_objective_and_gradient
@@ -277,7 +317,7 @@ def jax_mbar_objective(u_kn, N_k, f_k):
return obj


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_objective_and_gradient(u_kn, N_k, f_k):
"""JAX version of mbar_objective_and_gradient.
For parameters, see mbar_objective_and_gradient
@@ -331,7 +371,7 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k):
return jax_mbar_objective_and_gradient(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_hessian(u_kn, N_k, f_k):
"""JAX version of mbar_hessian.
For parameters, see mbar_hessian
Expand Down Expand Up @@ -375,7 +415,7 @@ def mbar_hessian(u_kn, N_k, f_k):
return jax_mbar_hessian(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_log_W_nk(u_kn, N_k, f_k):
"""JAX version of mbar_log_W_nk.
For parameters, see mbar_log_W_nk
Expand Down Expand Up @@ -412,7 +452,7 @@ def mbar_log_W_nk(u_kn, N_k, f_k):
return jax_mbar_log_W_nk(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_W_nk(u_kn, N_k, f_k):
"""JAX version of mbar_W_nk.
For parameters, see mbar_W_nk
Expand Down Expand Up @@ -606,7 +646,7 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None):
return results


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_core_adaptive(u_kn, N_k, f_k, gamma):
"""JAX version of adaptive inner loop.
N_k must be float (should be cast at a higher level)
Expand All @@ -633,7 +673,7 @@ def jax_core_adaptive(u_kn, N_k, f_k, gamma):
return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_precondition_u_kn(u_kn, N_k, f_k):
"""JAX version of precondition_u_kn
For parameters, see precondition_u_kn