Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn about JAX bitsize changes #504

Merged
merged 2 commits into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ where `results` is a dictionary with keys `mu`, `sigma`, and `Theta`, where `mu[

See the docstring help for these individual methods for more information on exact usage; in Python or IPython, you can view the docstrings with `help()`.

JAX needs 64-bit mode
---------------------
PyMBAR needs 64-bit floats to provide reliable answers. JAX by default uses
[32-bit (Single) bitsize](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision).
PyMBAR will turn on JAX's 64-bit mode, which may cause issues with some separate uses of JAX in the same code as PyMBAR,
such as existing Neural Network (NN) Models for machine learning.

Authors
-------
* Kyle A. Beauchamp <[email protected]>
Expand Down
68 changes: 54 additions & 14 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import scipy.optimize
from pymbar.utils import ensure_type, check_w_normalized, ParameterError

logger = logging.getLogger(__name__)

use_jit = False
force_no_jax = False # Temporary until we can make a proper setting to enable/disable by choice
try:
Expand All @@ -17,7 +19,21 @@
try:
from jax.config import config

config.update("jax_enable_x64", True)
if not config.x64_enabled:
# Warn that we're going to be setting 64 bit jax
logger.warning(
"\n"
"****** PyMBAR will use 64-bit JAX! *******\n"
"* JAX is currently set to 32-bit bitsize *\n"
"* which is its default. *\n"
"* *\n"
"* PyMBAR requires 64-bit mode and WILL *\n"
"* enable JAX's 64-bit mode when called. *\n"
"* *\n"
"* This MAY cause problems with other *\n"
"* Uses of JAX in the same code. *\n"
"******************************************\n"
)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
Expand All @@ -30,7 +46,7 @@
use_jit = True
except ImportError:
# Catch no JAX and throw a warning
warnings.warn(
logger.warning(
"\n"
"********* JAX NOT FOUND *********\n"
" PyMBAR can run faster with JAX \n"
Expand Down Expand Up @@ -61,14 +77,13 @@ def jit_or_passthrough(fn):
# Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax.
# Can be mostly ignored

logger = logging.getLogger(__name__)

if use_jit is False:
logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy")
else:
logger.info("JAX detected. Using JAX acceleration.")

# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving the MBAR equations.
# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving
# the MBAR equations.
# Note: we use tuples instead of lists to avoid accidental mutability.
JAX_SOLVER_PROTOCOL = (
dict(method="BFGS", continuation=True),
Expand Down Expand Up @@ -110,6 +125,31 @@ def jit_or_passthrough(fn):
scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included


def jit_or_pass_after_bitsize(jitable_fn):
"""
Attempt to set JAX precision if present. This does nothing if JAX is not present

Parameters
----------
jitable_fn: function
A function which can be jit'd
"""

# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit will being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
return jit_or_passthrough(jitable_fn)


def validate_inputs(u_kn, N_k, f_k):
"""Check types and return inputs for MBAR calculations.

Expand Down Expand Up @@ -167,7 +207,7 @@ def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None):
return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def _jit_self_consistent_update(u_kn, N_k, f_k):
"""JAX version of self_consistent update. For parameters, see self_consistent_update.
N_k must be float (should be cast at a higher level)
Expand Down Expand Up @@ -220,7 +260,7 @@ def mbar_gradient(u_kn, N_k, f_k):
return jax_mbar_gradient(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_gradient(u_kn, N_k, f_k):
"""JAX version of MBAR gradient function. See documentation of mbar_gradient.
N_k must be float (should be cast at a higher level)
Expand Down Expand Up @@ -263,7 +303,7 @@ def mbar_objective(u_kn, N_k, f_k):
return jax_mbar_objective(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_objective(u_kn, N_k, f_k):
"""JAX version of mbar_objective.
For parameters, mbar_objective_and_Gradient
Expand All @@ -277,7 +317,7 @@ def jax_mbar_objective(u_kn, N_k, f_k):
return obj


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_objective_and_gradient(u_kn, N_k, f_k):
"""JAX version of mbar_objective_and_gradient.
For parameters, mbar_objective_and_Gradient
Expand Down Expand Up @@ -331,7 +371,7 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k):
return jax_mbar_objective_and_gradient(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_hessian(u_kn, N_k, f_k):
"""JAX version of mbar_hessian.
For parameters, see mbar_hessian
Expand Down Expand Up @@ -375,7 +415,7 @@ def mbar_hessian(u_kn, N_k, f_k):
return jax_mbar_hessian(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_log_W_nk(u_kn, N_k, f_k):
"""JAX version of mbar_log_W_nk.
For parameters, see mbar_log_W_nk
Expand Down Expand Up @@ -412,7 +452,7 @@ def mbar_log_W_nk(u_kn, N_k, f_k):
return jax_mbar_log_W_nk(u_kn, N_k, f_k)


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_mbar_W_nk(u_kn, N_k, f_k):
"""JAX version of mbar_W_nk.
For parameters, see mbar_W_nk
Expand Down Expand Up @@ -606,7 +646,7 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None):
return results


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_core_adaptive(u_kn, N_k, f_k, gamma):
"""JAX version of adaptive inner loop.
N_k must be float (should be cast at a higher level)
Expand All @@ -633,7 +673,7 @@ def jax_core_adaptive(u_kn, N_k, f_k, gamma):
return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr


@jit_or_passthrough
@jit_or_pass_after_bitsize
def jax_precondition_u_kn(u_kn, N_k, f_k):
"""JAX version of precondition_u_kn
for parameters, see precondition_u_kn
Expand Down