Make JAX an optional dependency #503

Merged: 3 commits, Jun 13, 2023
16 changes: 14 additions & 2 deletions README.md
@@ -19,12 +19,24 @@ The easiest way to install the `pymbar` release is via [conda](http://conda.pyda
```bash
conda install -c conda-forge pymbar
```
which comes with JAX to speed up the code. To get the non-JAX-accelerated version instead:
```bash
conda install -c conda-forge pymbar-core
```

You can also install `pymbar` from the [Python package index](https://pypi.python.org/pypi/pymbar) using `pip`:

You can also install the JAX-accelerated `pymbar` from the [Python package index](https://pypi.python.org/pypi/pymbar) using `pip`:
```bash
pip install pymbar[jax]
```
or the non-JAX-accelerated version with
```bash
pip install pymbar
```
Whether you install the JAX-accelerated or non-JAX-accelerated version does not
change any calls or how the code is run. The non-JAX version takes less disk space because of its
smaller dependencies, but may not run as fast.


The development version can be installed directly from github via `pip`:

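As a usage note on the README change above: the public API is the same with or without JAX, so a call like the following runs unchanged under either install. This is a minimal sketch assuming the standard pymbar 4 `MBAR` interface (`MBAR(u_kn, N_k)` and `compute_free_energy_differences()`) with made-up toy data; it is not taken from the PR.

```python
import numpy as np
import pymbar

# Toy reduced potentials: 2 thermodynamic states, 10 samples drawn from each.
rng = np.random.default_rng(0)
u_kn = rng.normal(size=(2, 20))   # u_kn[k, n]: reduced potential of sample n evaluated in state k
N_k = np.array([10, 10])          # number of samples drawn from each state

# Identical call whether the JAX-accelerated or the pure-NumPy build is installed.
mbar = pymbar.MBAR(u_kn, N_k)
results = mbar.compute_free_energy_differences()
print(results["Delta_f"])         # matrix of dimensionless free energy differences
```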
2 changes: 0 additions & 2 deletions examples/harmonic-oscillators/harmonic-oscillators.py
@@ -39,7 +39,6 @@ def stddev_away(namex, errorx, dx):


def get_analytical(beta, K, O, observables):

# For a harmonic oscillator with spring constant K,
# x ~ Normal(x_0, sigma^2), where sigma = 1/sqrt(beta K)

@@ -670,7 +669,6 @@ def generate_fes_data(
def generate_fes_data(
ndim=1, nbinsperdim=15, nsamples=1000, K0=20.0, Ku=100.0, gridscale=0.2, xrange=((-3, 3),)
):

x0 = np.zeros([ndim], np.float64) # center of base potential
numbrellas = 1
nperdim = np.zeros([ndim], int)
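For context on the analytical comparison noted in the comment above: for a harmonic potential U(x) = (K/2)(x - x_0)^2 at inverse temperature beta, the configurational distribution is Gaussian with sigma = 1/sqrt(beta K), so samples can be drawn directly. A short illustration with hypothetical values (not taken from the script):

```python
import numpy as np

# x ~ Normal(x_0, sigma^2) with sigma = 1 / sqrt(beta * K), matching the docstring above.
rng = np.random.default_rng(1)
beta, K, x_0 = 1.0, 20.0, 0.0                 # example values, not from the example script
sigma = 1.0 / np.sqrt(beta * K)
x_samples = rng.normal(loc=x_0, scale=sigma, size=1000)
print(x_samples.std())                        # should be close to sigma ~= 0.2236
```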
3 changes: 1 addition & 2 deletions examples/heat-capacity/heat-capacity.py
@@ -13,6 +13,7 @@
import pymbar # for MBAR analysis
from pymbar import timeseries # for timeseries analysis


# ===================================================================================================
# INPUT PARAMETERS
# ===================================================================================================
@@ -175,7 +176,6 @@ def read_simulation_temps(pathname, num_temps):


def print_results(string, E, dE, Cv, dCv, types):

print(string)
print("Temperature dA <E> +/- d<E> ", end=" ")
for t in types:
@@ -403,7 +403,6 @@ def print_results(string, E, dE, Cv, dCv, types):

# only loop over the points that will be plotted, not the ones that
for i in range(originalK, K):

# Now, calculate heat capacity by T-differences
im = i - 1
ip = i + 1
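The hunk above computes heat capacity by temperature differences (`im = i - 1`, `ip = i + 1`). A minimal, hypothetical sketch of that central-difference idea, Cv(T_i) ≈ (⟨E⟩_{i+1} - ⟨E⟩_{i-1}) / (T_{i+1} - T_{i-1}), written independently of the script's variable names:

```python
import numpy as np

def cv_central_difference(E_mean, temps):
    """Central finite-difference estimate of Cv = d<E>/dT at the interior temperatures."""
    E_mean = np.asarray(E_mean, dtype=float)
    temps = np.asarray(temps, dtype=float)
    return (E_mean[2:] - E_mean[:-2]) / (temps[2:] - temps[:-2])

# Example: <E> rising linearly with T gives a constant Cv.
print(cv_central_difference([1.0, 2.0, 3.0, 4.0], [300.0, 310.0, 320.0, 330.0]))  # -> [0.1 0.1]
```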
@@ -277,7 +277,6 @@ def dddeltag(c, scalef=1, n=nspline):

feses = {}
for methodfull in methods:

# create a fresh copy of the initialized fes object. Operate on that within the loop.
# do the deepcopy here since there seem to be issues if it's done after data is added
# For example, the scikit-learn kde object fails to deepcopy.
@@ -302,7 +301,6 @@
)

if method == "kde":

kde_parameters = {}
# set the bandwidth for the kde.
kde_parameters["bandwidth"] = 0.5 * ((chi_max - chi_min) / nbins)
@@ -315,7 +313,6 @@
f_i_kde = results["f_i"] # kde results

if method in ["unbiased", "biased", "simple"]:

spline_parameters = {}
if method == "unbiased":
spline_parameters["spline_weights"] = "unbiasedstate"
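A standalone sketch of the KDE idea the example's `kde` branch relies on: the FES is, up to an additive constant, the negative log of a kernel density estimate of the collective variable. This uses scikit-learn directly with made-up data and the same bandwidth rule as above; the real code additionally weights samples by the MBAR weights, which this toy version omits.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

# Made-up samples of a 1-D collective variable.
rng = np.random.default_rng(2)
x_n = rng.normal(0.0, 1.0, size=(500, 1))
chi_min, chi_max, nbins = -3.0, 3.0, 30
bandwidth = 0.5 * ((chi_max - chi_min) / nbins)       # same heuristic as the example above

kde = KernelDensity(bandwidth=bandwidth).fit(x_n)     # pymbar would also pass MBAR sample weights
grid = np.linspace(chi_min, chi_max, nbins)[:, None]
f_i = -kde.score_samples(grid)                        # -log density, i.e. the FES up to a constant
f_i -= f_i.min()                                      # shift so the lowest point is zero
```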
6 changes: 2 additions & 4 deletions pymbar/confidenceintervals.py
@@ -31,7 +31,6 @@


def order_replicates(replicates, K):

"""
TODO: Add description for this function and types for parameters

@@ -56,7 +55,7 @@ def order_replicates(replicates, K):
sigma += sigmacorr

yi = []
for (replicate_index, replicate) in enumerate(replicates):
for replicate_index, replicate in enumerate(replicates):
yi.append(replicate["error"] / sigma)
yiarray = np.asarray(yi)
sortedyi = np.zeros(np.shape(yiarray))
@@ -76,7 +75,6 @@


def anderson_darling(replicates, K):

"""
TODO: Description here

@@ -300,7 +298,7 @@ def generate_confidence_intervals(replicates, K):
b = 1.0
# how many dimensions in the data?

for (replicate_index, replicate) in enumerate(replicates):
for replicate_index, replicate in enumerate(replicates):
# Compute fraction of free energy differences where error <= alpha sigma
# We only count differences where the analytical difference is larger than a cutoff, so that the results will not be limited by machine precision.
if dim == 0:
20 changes: 0 additions & 20 deletions pymbar/fes.py
@@ -229,7 +229,6 @@ def generate_fes(
n_bootstraps=0,
seed=-1,
):

"""
Given an initialized MBAR object, a set of points,
the desired energies at that point, and a method, generate
@@ -439,7 +438,6 @@ def generate_fes(
return result_vals # should we return results under some other conditions?

def _setup_fes_histogram(self, histogram_parameters):

"""
Does initial processing of histogram_parameters

@@ -476,7 +474,6 @@ def _setup_fes_histogram(self, histogram_parameters):
self.histogram_datas = None

def _generate_fes_histogram(self, b, x_n, w_nb, log_w_nb):

"""
Parameters
----------
@@ -603,7 +600,6 @@ def _generate_fes_histogram(self, b, x_n, w_nb, log_w_nb):
self.histogram_datas.append(histogram_data)

def _setup_fes_kde(self, kde_parameters):

"""
Does initial processing of kde_parameters

@@ -652,7 +648,6 @@ def _setup_fes_kde(self, kde_parameters):
self.kde = kde

def _generate_fes_kde(self, b, x_n, w_n):

"""
Given an fes object with the kde data set up, determine
the information necessary to define a FES using a kernel density approximation
@@ -704,7 +699,6 @@ def _generate_fes_kde(self, b, x_n, w_n):
self.kdes.append(kde)

def _setup_fes_spline(self, spline_parameters):

"""
Does initial processing of spline_parameters

@@ -813,7 +807,6 @@ def _setup_fes_spline(self, spline_parameters):
self.fes_functions = None

def _get_initial_spline_points(self):

"""
Uses information from spline_parameters to construct initial
points to create a spline from which to start the minimization.
@@ -888,7 +881,6 @@ def _get_initial_spline_points(self):
return xinit, yinit

def _get_initial_spline(self, xinit, yinit):

"""
Uses information from spline_parameters to construct initial
points to create a spline from which to start the minimization.
@@ -977,7 +969,6 @@ def _get_initial_spline(self, xinit, yinit):
return spline_data

def _generate_fes_spline(self, b, x_n, w_n):

"""
Given an fes object with the spline set up, determine
the information necessary to define a FES.
@@ -1046,7 +1037,6 @@ def _generate_fes_spline(self, b, x_n, w_n):
firsttime = True

while dg > tol: # until we reach the tolerance.

f = func(xi, *spline_args)

# we need some error handling: if we stepped too far, we should go back
@@ -1109,7 +1099,6 @@ def _generate_fes_spline(self, b, x_n, w_n):

@staticmethod
def _calculate_information_criteria(nparameters, minus_log_likelihood, N):

"""
Calculate and store various information criteria

@@ -1369,7 +1358,6 @@ def _get_fes_histogram(
raise ParameterError("Specified reference point for FES not given")

if reference_point in ["from-lowest", "from-specified", "all-differences"]:

if reference_point == "from-lowest":
# Determine free energy with lowest free energy to serve as reference point
j = histogram_data["f"].argmin()
@@ -1597,7 +1585,6 @@ def _get_fes_kde(
df_i = None

elif uncertainty_method == "bootstrap":

if self.kdes is None:
raise ParameterError(
f"Cannot calculate bootstrap error of boostrap KDE's not determined"
@@ -1867,7 +1854,6 @@ def prob(x):
self.mc_data["g"] = guse # statistical efficiency used for subsampling

def get_confidence_intervals(self, xplot, plow, phigh, reference="zero"):

"""
Parameters
----------
@@ -1937,7 +1923,6 @@ def get_confidence_intervals(self, xplot, plow, phigh, reference="zero"):
return return_vals

def get_mc_data(self):

"""convenience function to retrieve MC data

Parameters
@@ -1964,7 +1949,6 @@ def get_mc_data(self):
return self.mc_data

def _get_MC_loglikelihood(self, x_n, w_n, spline_weights, spline, xrange):

"""
Parameters
----------
@@ -2023,7 +2007,6 @@ def expk(x, kf):
return loglikelihood

def _MC_step(self, x_n, w_n, stepsize, xrange, spline_weights, logprior):

"""sample over the posterior space of the FES as splined.

Parameters
@@ -2114,7 +2097,6 @@ def prob(x):
return results

def _bspline_calculate_f(self, xi, x_n, w_n):

"""Calculate the maximum likelihood / KL divergence of the FES represented using B-splines.

Parameters
@@ -2321,7 +2303,6 @@ def dexpf(x, index):
return g

def _bspline_calculate_h(self, xi, x_n, w_n):

"""Calculate the Hessian of the maximum likelihood / KL divergence of the FES represented using B-splines.

Parameters
@@ -2411,7 +2392,6 @@ def ddexpf(x, index_i, index_j):
for i in range(nspline - 1):
for j in range(0, i + 1):
if np.abs(i - j) <= kdegree:

# now compute the expectation of each derivative
pE = self._integrate(
ddexpf,
5 changes: 1 addition & 4 deletions pymbar/mbar.py
@@ -57,6 +57,7 @@
JAX_SOLVER_PROTOCOL = mbar_solvers.JAX_SOLVER_PROTOCOL
BOOTSTRAP_SOLVER_PROTOCOL = mbar_solvers.BOOTSTRAP_SOLVER_PROTOCOL


# =========================================================================
# MBAR class definition
# =========================================================================
@@ -364,7 +365,6 @@ def __init__(
protocols = {pnames[0]: solver_protocol, pnames[1]: bootstrap_solver_protocol}

for defl, rob, pname in zip(defaults, robusts, pnames):

prot = protocols[pname]
if prot is None or prot == "default":
prot = defl
@@ -972,7 +972,6 @@ def compute_expectations_inner(
A_n[i, :] = A_n[i, :] + (A_min[i] - logfactors[i])

if return_theta:

# Note: these variances will be the same whether or not we
# subtract a different constant from each A_i
# for efficiency, output theta in block form
@@ -1029,7 +1028,6 @@ def compute_expectations_inner(

# =========================================================================
def compute_covariance_of_sums(self, d_ij, K, a):

"""
We wish to calculate the variance of a weighted sum of free energy differences.
for example ``var(\\sum a_i df_i)``.
@@ -1919,7 +1917,6 @@ def _computeUnnormalizedLogWeights(self, u_n):
return -1.0 * logsumexp(self.f_k + u_n[:, np.newaxis] - self.u_kn.T, b=self.N_k, axis=1)

def _initialize_with_bar(self, u_kn, f_k_init=None):

"""

Internal method for initializing free energies of simulations with BAR.
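The `compute_covariance_of_sums` docstring above refers to the variance of a weighted sum of free energy differences. The identity behind it is Var(sum_i a_i df_i) = aᵀ C a for a covariance matrix C. A generic sketch of that identity (not pymbar's internal implementation, which works from the `d_ij` uncertainty matrix):

```python
import numpy as np

def variance_of_weighted_sum(a, cov):
    """Var(sum_i a_i * x_i) = a^T C a for weights a and covariance matrix C."""
    a = np.asarray(a, dtype=float)
    cov = np.asarray(cov, dtype=float)
    return float(a @ cov @ a)

# Two perfectly correlated unit-variance quantities with weights (1, -1): variance 0.
print(variance_of_weighted_sum([1.0, -1.0], [[1.0, 1.0], [1.0, 1.0]]))  # -> 0.0
```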
44 changes: 29 additions & 15 deletions pymbar/mbar_solvers.py
@@ -12,20 +12,36 @@
try:
#### JAX related imports
if force_no_jax:
# Capture user-disabled JAX instead of "JAX not found"
raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py")
from jax.config import config

config.update("jax_enable_x64", True)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
from jax.numpy.linalg import lstsq
import jax.scipy.optimize as optimize_maybe_jax
from jax.scipy.special import logsumexp

from jax import jit as jit_or_passthrough

use_jit = True
try:
from jax.config import config

config.update("jax_enable_x64", True)

from jax.numpy import exp, sum, newaxis, diag, dot, s_
from jax.numpy import pad as npad
from jax.numpy.linalg import lstsq
import jax.scipy.optimize as optimize_maybe_jax
from jax.scipy.special import logsumexp

from jax import jit as jit_or_passthrough

use_jit = True
except ImportError:
# Catch no JAX and throw a warning
warnings.warn(
"\n"
"********* JAX NOT FOUND *********\n"
" PyMBAR can run faster with JAX \n"
" But will work fine without it \n"
"Either install with pip or conda:\n"
" pip install pybar[jax] \n"
" OR \n"
" conda install pymbar \n"
"*********************************"
)
raise # Continue with the raised Import Error

except ImportError:
# No JAX found; fall back to the equivalent NumPy-based imports
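A usage note on the fallback above: the module-level `use_jit` flag set in both branches can be inspected at runtime to see which code path was selected (assuming the flag remains importable, as it is in this diff):

```python
from pymbar import mbar_solvers

# True if the JAX imports succeeded and the jit-compiled kernels are in use,
# False if the pure-NumPy fallback was selected.
if mbar_solvers.use_jit:
    print("pymbar is using the JAX-accelerated solvers")
else:
    print("pymbar is using the pure-NumPy solvers")
```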
@@ -431,7 +447,6 @@ def mbar_W_nk(u_kn, N_k, f_k):


def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None):

"""
Determine dimensionless free energies by a combination of Newton-Raphson iteration and self-consistent iteration.
Picks whichever method gives the lowest gradient.
Expand Down Expand Up @@ -497,7 +512,6 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None):
min_sc_iter = options["min_sc_iter"]
warn = "Did not converge."
for iteration in range(0, maxiter):

if use_jit:
(f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr) = jax_core_adaptive(
u_kn, N_k, f_k, options["gamma"]
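For readers unfamiliar with the solver above: `adaptive` alternates Newton-Raphson steps with self-consistent iteration and keeps whichever yields the lower gradient. A minimal NumPy sketch of just the self-consistent update (illustrative only; the production code runs the jittable kernels shown in this file and makes the gradient-based choice described in the docstring):

```python
import numpy as np
from scipy.special import logsumexp

def self_consistent_mbar(u_kn, N_k, tol=1.0e-8, max_iter=10000):
    """Plain self-consistent MBAR iteration: f_k = -ln sum_n exp(-u_kn) / sum_k' N_k' exp(f_k' - u_k'n)."""
    K, N = u_kn.shape
    f_k = np.zeros(K)
    for _ in range(max_iter):
        # log of the mixture denominator for each sample n
        log_denom_n = logsumexp(f_k[:, None] - u_kn, b=N_k[:, None], axis=0)
        f_k_new = -logsumexp(-u_kn - log_denom_n, axis=1)
        f_k_new -= f_k_new[0]            # fix the arbitrary additive constant
        if np.max(np.abs(f_k_new - f_k)) < tol:
            return f_k_new
        f_k = f_k_new
    return f_k

# Toy check with made-up reduced potentials.
rng = np.random.default_rng(3)
print(self_consistent_mbar(rng.normal(size=(2, 20)), np.array([10, 10])))
```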