diff --git a/README.md b/README.md index b2054aae..2005c383 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,13 @@ where `results` is a dictionary with keys `mu`, `sigma`, and `Theta`, where `mu[ See the docstring help for these individual methods for more information on exact usage; in Python or IPython, you can view the docstrings with `help()`. +JAX needs 64-bit mode +--------------------- +PyMBAR needs 64-bit floats to provide reliable answers. JAX by default uses +[32-bit (Single) bitsize](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). +PyMBAR will turn on JAX's 64-bit mode, which may cause issues with some separate uses of JAX in the same code as PyMBAR, +such as existing Neural Network (NN) Models for machine learning. + Authors ------- * Kyle A. Beauchamp diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index 31a24ccd..9ef56679 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -7,6 +7,8 @@ import scipy.optimize from pymbar.utils import ensure_type, check_w_normalized, ParameterError +logger = logging.getLogger(__name__) + use_jit = False force_no_jax = False # Temporary until we can make a proper setting to enable/disable by choice try: @@ -17,7 +19,21 @@ try: from jax.config import config - config.update("jax_enable_x64", True) + if not config.x64_enabled: + # Warn that we're going to be setting 64 bit jax + logger.warning( + "\n" + "****** PyMBAR will use 64-bit JAX! *******\n" + "* JAX is currently set to 32-bit bitsize *\n" + "* which is its default. *\n" + "* *\n" + "* PyMBAR requires 64-bit mode and WILL *\n" + "* enable JAX's 64-bit mode when called. *\n" + "* *\n" + "* This MAY cause problems with other *\n" + "* Uses of JAX in the same code. *\n" + "******************************************\n" + ) from jax.numpy import exp, sum, newaxis, diag, dot, s_ from jax.numpy import pad as npad @@ -30,7 +46,7 @@ use_jit = True except ImportError: # Catch no JAX and throw a warning - warnings.warn( + logger.warning( "\n" "********* JAX NOT FOUND *********\n" " PyMBAR can run faster with JAX \n" @@ -61,14 +77,13 @@ def jit_or_passthrough(fn): # Known issue with astroid<2.12 and numpy array returns, but 2.12 doesn't fix it due to returns being jax. # Can be mostly ignored -logger = logging.getLogger(__name__) - if use_jit is False: logger.info("JAX was either not detected or disabled, using standard NumPy and SciPy") else: logger.info("JAX detected. Using JAX acceleration.") -# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving the MBAR equations. +# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving +# the MBAR equations. # Note: we use tuples instead of lists to avoid accidental mutability. JAX_SOLVER_PROTOCOL = ( dict(method="BFGS", continuation=True), @@ -110,6 +125,31 @@ def jit_or_passthrough(fn): scipy_root_options = ["hybr", "lm"] # only use root options with the hessian included +def jit_or_pass_after_bitsize(jitable_fn): + """ + Attempt to set JAX precision if present. This does nothing if JAX is not present + + Parameters + ---------- + jitable_fn: function + A function which can be jit'd + """ + + # This will only trigger if JAX is set + if use_jit and not config.x64_enabled: + # Warn that JAX 64-bit will being turned on + logger.warning( + "\n" + "******* JAX 64-bit mode is now on! *******\n" + "* JAX is now set to 64-bit mode! *\n" + "* This MAY cause problems with other *\n" + "* uses of JAX in the same code. *\n" + "******************************************\n" + ) + config.update("jax_enable_x64", True) + return jit_or_passthrough(jitable_fn) + + def validate_inputs(u_kn, N_k, f_k): """Check types and return inputs for MBAR calculations. @@ -167,7 +207,7 @@ def self_consistent_update(u_kn, N_k, f_k, states_with_samples=None): return jax_self_consistent_update(u_kn, N_k, f_k, states_with_samples=states_with_samples) -@jit_or_passthrough +@jit_or_pass_after_bitsize def _jit_self_consistent_update(u_kn, N_k, f_k): """JAX version of self_consistent update. For parameters, see self_consistent_update. N_k must be float (should be cast at a higher level) @@ -220,7 +260,7 @@ def mbar_gradient(u_kn, N_k, f_k): return jax_mbar_gradient(u_kn, N_k, f_k) -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_mbar_gradient(u_kn, N_k, f_k): """JAX version of MBAR gradient function. See documentation of mbar_gradient. N_k must be float (should be cast at a higher level) @@ -263,7 +303,7 @@ def mbar_objective(u_kn, N_k, f_k): return jax_mbar_objective(u_kn, N_k, f_k) -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_mbar_objective(u_kn, N_k, f_k): """JAX version of mbar_objective. For parameters, mbar_objective_and_Gradient @@ -277,7 +317,7 @@ def jax_mbar_objective(u_kn, N_k, f_k): return obj -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_mbar_objective_and_gradient(u_kn, N_k, f_k): """JAX version of mbar_objective_and_gradient. For parameters, mbar_objective_and_Gradient @@ -331,7 +371,7 @@ def mbar_objective_and_gradient(u_kn, N_k, f_k): return jax_mbar_objective_and_gradient(u_kn, N_k, f_k) -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_mbar_hessian(u_kn, N_k, f_k): """JAX version of mbar_hessian. For parameters, see mbar_hessian @@ -375,7 +415,7 @@ def mbar_hessian(u_kn, N_k, f_k): return jax_mbar_hessian(u_kn, N_k, f_k) -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_mbar_log_W_nk(u_kn, N_k, f_k): """JAX version of mbar_log_W_nk. For parameters, see mbar_log_W_nk @@ -412,7 +452,7 @@ def mbar_log_W_nk(u_kn, N_k, f_k): return jax_mbar_log_W_nk(u_kn, N_k, f_k) -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_mbar_W_nk(u_kn, N_k, f_k): """JAX version of mbar_W_nk. For parameters, see mbar_W_nk @@ -606,7 +646,7 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None): return results -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_core_adaptive(u_kn, N_k, f_k, gamma): """JAX version of adaptive inner loop. N_k must be float (should be cast at a higher level) @@ -633,7 +673,7 @@ def jax_core_adaptive(u_kn, N_k, f_k, gamma): return f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr -@jit_or_passthrough +@jit_or_pass_after_bitsize def jax_precondition_u_kn(u_kn, N_k, f_k): """JAX version of precondition_u_kn for parameters, see precondition_u_kn