diff --git a/README.md b/README.md index 9d92f518..b2054aae 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,24 @@ The easiest way to install the `pymbar` release is via [conda](http://conda.pyda ```bash conda install -c conda-forge pymbar ``` +which will come with JAX to speed up the code. Or to get the non-JAX accelerated version: +```bash +conda install -c conda-forge pymbar-core +``` -You can also install `pymbar` from the [Python package index](https://pypi.python.org/pypi/pymbar) using `pip`: - +You can also install JAX accelerated `pymbar` from the [Python package index](https://pypi.python.org/pypi/pymbar) +using `pip`: +```bash +pip install pymbar[jax] +``` +or the non-jax-accelerated version with ```bash pip install pymbar ``` +Whether you install the JAX accelerated or non-JAX-accelerated version does not +change any calls or how the code is run. The non-Jax version is smaller on disk due to smaller +dependencies, but may not run as fast. + The development version can be installed directly from github via `pip`: diff --git a/examples/harmonic-oscillators/harmonic-oscillators.py b/examples/harmonic-oscillators/harmonic-oscillators.py index 5d08cd95..2d692a19 100644 --- a/examples/harmonic-oscillators/harmonic-oscillators.py +++ b/examples/harmonic-oscillators/harmonic-oscillators.py @@ -39,7 +39,6 @@ def stddev_away(namex, errorx, dx): def get_analytical(beta, K, O, observables): - # For a harmonic oscillator with spring constant K, # x ~ Normal(x_0, sigma^2), where sigma = 1/sqrt(beta K) @@ -670,7 +669,6 @@ def get_analytical(beta, K, O, observables): def generate_fes_data( ndim=1, nbinsperdim=15, nsamples=1000, K0=20.0, Ku=100.0, gridscale=0.2, xrange=((-3, 3),) ): - x0 = np.zeros([ndim], np.float64) # center of base potential numbrellas = 1 nperdim = np.zeros([ndim], int) diff --git a/examples/heat-capacity/heat-capacity.py b/examples/heat-capacity/heat-capacity.py index 329f9dd0..57296e57 100644 --- a/examples/heat-capacity/heat-capacity.py +++ b/examples/heat-capacity/heat-capacity.py @@ -13,6 +13,7 @@ import pymbar # for MBAR analysis from pymbar import timeseries # for timeseries analysis + # =================================================================================================== # INPUT PARAMETERS # =================================================================================================== @@ -175,7 +176,6 @@ def read_simulation_temps(pathname, num_temps): def print_results(string, E, dE, Cv, dCv, types): - print(string) print("Temperature dA +/- d ", end=" ") for t in types: @@ -403,7 +403,6 @@ def print_results(string, E, dE, Cv, dCv, types): # only loop over the points that will be plotted, not the ones that for i in range(originalK, K): - # Now, calculae heat capacity by T-differences im = i - 1 ip = i + 1 diff --git a/examples/umbrella-sampling-fes/umbrella-sampling-advanced-fes.py b/examples/umbrella-sampling-fes/umbrella-sampling-advanced-fes.py index 575ce8ae..ad6a4e5a 100644 --- a/examples/umbrella-sampling-fes/umbrella-sampling-advanced-fes.py +++ b/examples/umbrella-sampling-fes/umbrella-sampling-advanced-fes.py @@ -277,7 +277,6 @@ def dddeltag(c, scalef=1, n=nspline): feses = {} for methodfull in methods: - # create a fresh copy of the initialized fes object. Operate on that within the loop. # do the deepcopy here since there seem to be issues if it's done after data is added # For example, the scikit-learn kde object fails to deepopy. @@ -302,7 +301,6 @@ def dddeltag(c, scalef=1, n=nspline): ) if method == "kde": - kde_parameters = {} # set the sigma for the spline. kde_parameters["bandwidth"] = 0.5 * ((chi_max - chi_min) / nbins) @@ -315,7 +313,6 @@ def dddeltag(c, scalef=1, n=nspline): f_i_kde = results["f_i"] # kde results if method in ["unbiased", "biased", "simple"]: - spline_parameters = {} if method == "unbiased": spline_parameters["spline_weights"] = "unbiasedstate" diff --git a/pymbar/confidenceintervals.py b/pymbar/confidenceintervals.py index 089d0fdb..0a196bb9 100644 --- a/pymbar/confidenceintervals.py +++ b/pymbar/confidenceintervals.py @@ -31,7 +31,6 @@ def order_replicates(replicates, K): - """ TODO: Add description for this function and types for parameters @@ -56,7 +55,7 @@ def order_replicates(replicates, K): sigma += sigmacorr yi = [] - for (replicate_index, replicate) in enumerate(replicates): + for replicate_index, replicate in enumerate(replicates): yi.append(replicate["error"] / sigma) yiarray = np.asarray(yi) sortedyi = np.zeros(np.shape(yiarray)) @@ -76,7 +75,6 @@ def order_replicates(replicates, K): def anderson_darling(replicates, K): - """ TODO: Description here @@ -300,7 +298,7 @@ def generate_confidence_intervals(replicates, K): b = 1.0 # how many dimensions in the data? - for (replicate_index, replicate) in enumerate(replicates): + for replicate_index, replicate in enumerate(replicates): # Compute fraction of free energy differences where error <= alpha sigma # We only count differences where the analytical difference is larger than a cutoff, so that the results will not be limited by machine precision. if dim == 0: diff --git a/pymbar/fes.py b/pymbar/fes.py index 7bd729f7..b1dbbd05 100644 --- a/pymbar/fes.py +++ b/pymbar/fes.py @@ -229,7 +229,6 @@ def generate_fes( n_bootstraps=0, seed=-1, ): - """ Given an intialized MBAR object, a set of points, the desired energies at that point, and a method, generate @@ -439,7 +438,6 @@ def generate_fes( return result_vals # should we return results under some other conditions? def _setup_fes_histogram(self, histogram_parameters): - """ Does initial processsing of histogram_parameters @@ -476,7 +474,6 @@ def _setup_fes_histogram(self, histogram_parameters): self.histogram_datas = None def _generate_fes_histogram(self, b, x_n, w_nb, log_w_nb): - """ Parameters ---------- @@ -603,7 +600,6 @@ def _generate_fes_histogram(self, b, x_n, w_nb, log_w_nb): self.histogram_datas.append(histogram_data) def _setup_fes_kde(self, kde_parameters): - """ Does initial processsing of kde_parameters @@ -652,7 +648,6 @@ def _setup_fes_kde(self, kde_parameters): self.kde = kde def _generate_fes_kde(self, b, x_n, w_n): - """ Given an fes object with the kde data set up, determine the information necessary to define a FES using a kernel density approximation @@ -704,7 +699,6 @@ def _generate_fes_kde(self, b, x_n, w_n): self.kdes.append(kde) def _setup_fes_spline(self, spline_parameters): - """ Does initial processsing of spline_parameters @@ -813,7 +807,6 @@ def _setup_fes_spline(self, spline_parameters): self.fes_functions = None def _get_initial_spline_points(self): - """ Uses information from spline_parameters to construct initial points to create a spline frmo which to start the minimization. @@ -888,7 +881,6 @@ def _get_initial_spline_points(self): return xinit, yinit def _get_initial_spline(self, xinit, yinit): - """ Uses information from spline_parameters to construct initial points to create a spline frmo which to start the minimization. @@ -977,7 +969,6 @@ def _get_initial_spline(self, xinit, yinit): return spline_data def _generate_fes_spline(self, b, x_n, w_n): - """ Given an fes object with the spline set up, determine the information necessary to define a FES. @@ -1046,7 +1037,6 @@ def _generate_fes_spline(self, b, x_n, w_n): firsttime = True while dg > tol: # until we reach the tolerance. - f = func(xi, *spline_args) # we need some error handling: if we stepped too far, we should go back @@ -1109,7 +1099,6 @@ def _generate_fes_spline(self, b, x_n, w_n): @staticmethod def _calculate_information_criteria(nparameters, minus_log_likelihood, N): - """ Calculate and store various informaton criterias @@ -1369,7 +1358,6 @@ def _get_fes_histogram( raise ParameterError("Specified reference point for FES not given") if reference_point in ["from-lowest", "from-specified", "all-differences"]: - if reference_point == "from-lowest": # Determine free energy with lowest free energy to serve as reference point j = histogram_data["f"].argmin() @@ -1597,7 +1585,6 @@ def _get_fes_kde( df_i = None elif uncertainty_method == "bootstrap": - if self.kdes is None: raise ParameterError( f"Cannot calculate bootstrap error of boostrap KDE's not determined" @@ -1867,7 +1854,6 @@ def prob(x): self.mc_data["g"] = guse # statistical efficiency used for subsampling def get_confidence_intervals(self, xplot, plow, phigh, reference="zero"): - """ Parameters ---------- @@ -1937,7 +1923,6 @@ def get_confidence_intervals(self, xplot, plow, phigh, reference="zero"): return return_vals def get_mc_data(self): - """convenience function to retrieve MC data Parameters @@ -1964,7 +1949,6 @@ def get_mc_data(self): return self.mc_data def _get_MC_loglikelihood(self, x_n, w_n, spline_weights, spline, xrange): - """ Parameters ---------- @@ -2023,7 +2007,6 @@ def expk(x, kf): return loglikelihood def _MC_step(self, x_n, w_n, stepsize, xrange, spline_weights, logprior): - """sample over the posterior space of the FES as splined. Parameters @@ -2114,7 +2097,6 @@ def prob(x): return results def _bspline_calculate_f(self, xi, x_n, w_n): - """Calculate the maximum likelihood / KL divergence of the FES represented using B-splines. Parameters @@ -2321,7 +2303,6 @@ def dexpf(x, index): return g def _bspline_calculate_h(self, xi, x_n, w_n): - """Calculate the Hessian of the maximum likelihood / KL divergence of the FES represented using B-splines. Parameters @@ -2411,7 +2392,6 @@ def ddexpf(x, index_i, index_j): for i in range(nspline - 1): for j in range(0, i + 1): if np.abs(i - j) <= kdegree: - # now compute the expectation of each derivative pE = self._integrate( ddexpf, diff --git a/pymbar/mbar.py b/pymbar/mbar.py index 27291ed2..893a46a9 100644 --- a/pymbar/mbar.py +++ b/pymbar/mbar.py @@ -57,6 +57,7 @@ JAX_SOLVER_PROTOCOL = mbar_solvers.JAX_SOLVER_PROTOCOL BOOTSTRAP_SOLVER_PROTOCOL = mbar_solvers.BOOTSTRAP_SOLVER_PROTOCOL + # ========================================================================= # MBAR class definition # ========================================================================= @@ -364,7 +365,6 @@ def __init__( protocols = {pnames[0]: solver_protocol, pnames[1]: bootstrap_solver_protocol} for defl, rob, pname in zip(defaults, robusts, pnames): - prot = protocols[pname] if prot is None or prot == "default": prot = defl @@ -972,7 +972,6 @@ def compute_expectations_inner( A_n[i, :] = A_n[i, :] + (A_min[i] - logfactors[i]) if return_theta: - # Note: these variances will be the same whether or not we # subtract a different constant from each A_i # for efficency, output theta in block form @@ -1029,7 +1028,6 @@ def compute_expectations_inner( # ========================================================================= def compute_covariance_of_sums(self, d_ij, K, a): - """ We wish to calculate the variance of a weighted sum of free energy differences. for example ``var(\\sum a_i df_i)``. @@ -1919,7 +1917,6 @@ def _computeUnnormalizedLogWeights(self, u_n): return -1.0 * logsumexp(self.f_k + u_n[:, np.newaxis] - self.u_kn.T, b=self.N_k, axis=1) def _initialize_with_bar(self, u_kn, f_k_init=None): - """ Internal method for intializing free energies simulations with BAR. diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index c87ad2d8..31a24ccd 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -12,20 +12,36 @@ try: #### JAX related imports if force_no_jax: + # Capture user-disabled JAX instead "JAX not found" raise ImportError("Jax disabled by force_no_jax in mbar_solvers.py") - from jax.config import config - - config.update("jax_enable_x64", True) - - from jax.numpy import exp, sum, newaxis, diag, dot, s_ - from jax.numpy import pad as npad - from jax.numpy.linalg import lstsq - import jax.scipy.optimize as optimize_maybe_jax - from jax.scipy.special import logsumexp - - from jax import jit as jit_or_passthrough - - use_jit = True + try: + from jax.config import config + + config.update("jax_enable_x64", True) + + from jax.numpy import exp, sum, newaxis, diag, dot, s_ + from jax.numpy import pad as npad + from jax.numpy.linalg import lstsq + import jax.scipy.optimize as optimize_maybe_jax + from jax.scipy.special import logsumexp + + from jax import jit as jit_or_passthrough + + use_jit = True + except ImportError: + # Catch no JAX and throw a warning + warnings.warn( + "\n" + "********* JAX NOT FOUND *********\n" + " PyMBAR can run faster with JAX \n" + " But will work fine without it \n" + "Either install with pip or conda:\n" + " pip install pybar[jax] \n" + " OR \n" + " conda install pymbar \n" + "*********************************" + ) + raise # Continue with the raised Import Error except ImportError: # No JAX found, overlap imports @@ -431,7 +447,6 @@ def mbar_W_nk(u_kn, N_k, f_k): def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None): - """ Determine dimensionless free energies by a combination of Newton-Raphson iteration and self-consistent iteration. Picks whichever method gives the lowest gradient. @@ -497,7 +512,6 @@ def adaptive(u_kn, N_k, f_k, tol=1.0e-8, options=None): min_sc_iter = options["min_sc_iter"] warn = "Did not converge." for iteration in range(0, maxiter): - if use_jit: (f_sci, g_sci, gnorm_sci, f_nr, g_nr, gnorm_nr) = jax_core_adaptive( u_kn, N_k, f_k, options["gamma"] diff --git a/pymbar/other_estimators.py b/pymbar/other_estimators.py index 7dde3901..9c2dab6c 100644 --- a/pymbar/other_estimators.py +++ b/pymbar/other_estimators.py @@ -290,7 +290,6 @@ def bar( # Iterate to convergence or until maximum number of iterations has been exceeded. for iteration in range(maximum_iterations + 1): - DeltaF_old = DeltaF if method == "false-position": @@ -370,7 +369,6 @@ def bar( raise ConvergenceError(message) if compute_uncertainty: - ############# # Compute asymptotic variance estimate using Eq. 10a of Bennett, # 1976 (except with n_1_1^2 in the second denominator, it is diff --git a/pymbar/tests/test_bar.py b/pymbar/tests/test_bar.py index 446abc47..29ffda5c 100644 --- a/pymbar/tests/test_bar.py +++ b/pymbar/tests/test_bar.py @@ -58,7 +58,6 @@ def test_sample(system_generator): def test_bar_free_energies(bar_and_test): - """Can bar calculate moderately correct free energy differences?""" bars, test = bar_and_test["bars"], bar_and_test["test"] @@ -100,7 +99,6 @@ def test_bar_free_energies(bar_and_test): def test_bar_overlap(): - for system_generator in system_generators: name, test = system_generator() x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode="u_kn") diff --git a/pymbar/tests/test_exp.py b/pymbar/tests/test_exp.py index 6936589c..c1181c26 100644 --- a/pymbar/tests/test_exp.py +++ b/pymbar/tests/test_exp.py @@ -56,7 +56,6 @@ def test_sample(system_generator): def test_EXP_free_energies(exp_and_test): - """Can exp calculate moderately correct free energy differences?""" exps, test = exp_and_test["exps"], exp_and_test["test"] diff --git a/pymbar/tests/test_fes.py b/pymbar/tests/test_fes.py index b95ba2d5..f2d2ec35 100644 --- a/pymbar/tests/test_fes.py +++ b/pymbar/tests/test_fes.py @@ -17,7 +17,6 @@ def generate_fes_data(ndim=1, nsamples=1000, K0=20.0, Ku=100.0, gridscale=0.2, xrange=None): - x0 = np.zeros([ndim]) # center of base potential numbrellas = 1 nperdim = np.zeros([ndim], int) @@ -98,7 +97,6 @@ def bias_potential(x, k_bias): @pytest.fixture(scope="module") def fes_1d(): - gridscale = 0.2 nbinsperdim = 15 xrange = [[-3, 3]] @@ -190,7 +188,6 @@ def fes_1d(): @pytest.fixture(scope="module") def fes_2d(): - xrange = [[-3, 3], [-3, 3]] ndim = 2 nsamples = 300 @@ -320,7 +317,6 @@ def fes_2d(): ], ) def test_1d_fes_histogram(fes_1d, reference_point): - fes = fes_1d["fes"] histogram_parameters = dict() @@ -345,7 +341,6 @@ def test_1d_fes_histogram(fes_1d, reference_point): def base_1d_fes_kde(fes_1d, gen_kwargs, reference_point): - fes = fes_1d["fes"] kde_parameters = dict() @@ -405,7 +400,6 @@ def test_1d_fes_kde_bootstraped(fes_1d): def base_1d_fes_spline(fes_1d, gen_kwargs, reference_point): - fes = fes_1d["fes"] bin_centers = fes_1d["bin_centers"] fes_analytical = fes_1d["fes_analytical"] @@ -487,7 +481,6 @@ def test_1d_fes_spline_bootstraped(fes_1d): ], ) def test_2d_fes_histogram(fes_2d, reference_point): - """testing fes_generate_fes and fes_get_fes in 2D""" fes = fes_2d["fes"] @@ -536,7 +529,6 @@ def test_2d_fes_histogram(fes_2d, reference_point): ], ) def test_2d_fes_kde(fes_2d, gen_kwargs, reference_point): - fes = fes_2d["fes"] fes_analytical = fes_2d["fes_analytical"] diff --git a/pymbar/tests/test_mbar.py b/pymbar/tests/test_mbar.py index 378db1ae..b14055ee 100644 --- a/pymbar/tests/test_mbar.py +++ b/pymbar/tests/test_mbar.py @@ -181,7 +181,6 @@ def test_sample(system_generator): ], ) def test_mbar_free_energies(mbar_and_test, uncertainty_method): - """Can MBAR calculate moderately correct free energy differences?""" mbar, test = mbar_and_test["mbar"], mbar_and_test["test"] @@ -216,7 +215,6 @@ def test_mbar_initialization(fixed_harmonic_sample, method): def test_mbar_compute_expectations_position_averages(mbar_and_test): - """Can MBAR calculate E(x_n)??""" mbar, test, x_n = mbar_and_test["mbar"], mbar_and_test["test"], mbar_and_test["x_n"] @@ -231,7 +229,6 @@ def test_mbar_compute_expectations_position_averages(mbar_and_test): def test_mbar_compute_expectations_position_differences(mbar_and_test): - """Can MBAR calculate E(x_n)??""" mbar, test, x_n = mbar_and_test["mbar"], mbar_and_test["test"], mbar_and_test["x_n"] results = mbar.compute_expectations(x_n, output="differences") @@ -244,7 +241,6 @@ def test_mbar_compute_expectations_position_differences(mbar_and_test): def test_mbar_compute_expectations_position2(mbar_and_test): - """Can MBAR calculate E(x_n^2)??""" mbar, test, x_n = mbar_and_test["mbar"], mbar_and_test["test"], mbar_and_test["x_n"] @@ -258,7 +254,6 @@ def test_mbar_compute_expectations_position2(mbar_and_test): def test_mbar_compute_expectations_potential(mbar_and_test): - """Can MBAR calculate E(u_kn)??""" mbar, test, u_kn = mbar_and_test["mbar"], mbar_and_test["test"], mbar_and_test["u_kn"] @@ -329,7 +324,6 @@ def multiExpectationAssertion(results, test, state=1): def test_mbar_compute_multiple_expectations(mbar_and_test): - """Can MBAR calculate E(u_kn)??""" mbar, test, x_n, u_kn = ( @@ -347,7 +341,6 @@ def test_mbar_compute_multiple_expectations(mbar_and_test): def test_mbar_compute_multiple_expectations_more_dims(mbar_and_test_kln): - """Can MBAR calculate E(u_kn) with 3 dimensions??""" mbar, test, x_n, u_kn = ( @@ -367,7 +360,6 @@ def test_mbar_compute_multiple_expectations_more_dims(mbar_and_test_kln): def test_mbar_compute_entropy_and_enthalpy(mbar_and_test, with_uxx=True): - """Can MBAR calculate f_k, and s_k ??""" mbar, test, x_n, u_kn = ( @@ -472,7 +464,6 @@ def test_mbar_compute_overlap_nonanalytical(mbar_and_test): def test_mbar_weights(mbar_and_test): - """testing weights""" mbar = mbar_and_test["mbar"] @@ -491,7 +482,6 @@ def test_mbar_weights(mbar_and_test): ], ) def test_mbar_computePerturbedFreeEnergeies(system_generator, mode, bad_n): - """testing compute_perturbed_free_energies""" # only do MBAR with the first and last set @@ -526,7 +516,6 @@ def test_mbar_computePerturbedFreeEnergeies(system_generator, mode, bad_n): def test_mbar_compute_expectations_inner(mbar_and_test): - """Can MBAR calculate general expectations inner code (note: this just tests completion)""" mbar, test, x_n, u_kn = ( diff --git a/pymbar/tests/test_timeseries.py b/pymbar/tests/test_timeseries.py index 11c9e185..e5455ba7 100644 --- a/pymbar/tests/test_timeseries.py +++ b/pymbar/tests/test_timeseries.py @@ -75,7 +75,6 @@ def test_statistical_inefficiency_fft(data): @has_statmodels def test_statistical_inefficiency_fft_gaussian(): - # Run multiple times to get things with and without negative "spikes" at C(1) for i in range(5): x = np.random.normal(size=100000) diff --git a/pymbar/testsystems/exponential_distributions.py b/pymbar/testsystems/exponential_distributions.py index 07aff03b..02031ebd 100644 --- a/pymbar/testsystems/exponential_distributions.py +++ b/pymbar/testsystems/exponential_distributions.py @@ -74,7 +74,6 @@ def analytical_standard_deviations(self): return np.sqrt(self.rates**-2.0) def analytical_observable(self, observable="position"): - if observable == "position": return self.analytical_means() if observable == "position^2": diff --git a/pymbar/testsystems/harmonic_oscillators.py b/pymbar/testsystems/harmonic_oscillators.py index 918490c5..0eac94c0 100644 --- a/pymbar/testsystems/harmonic_oscillators.py +++ b/pymbar/testsystems/harmonic_oscillators.py @@ -81,7 +81,6 @@ def analytical_standard_deviations(self): return (self.beta * self.K_k) ** -0.5 def analytical_observable(self, observable="position"): - if observable == "position": return self.analytical_means() if observable == "potential energy": diff --git a/pymbar/timeseries.py b/pymbar/timeseries.py index 48830444..5b0bb755 100644 --- a/pymbar/timeseries.py +++ b/pymbar/timeseries.py @@ -177,7 +177,6 @@ def statistical_inefficiency(A_n, B_n=None, fast=False, mintime=3, fft=False): t = 1 increment = 1 while t < N - 1: - # compute normalized fluctuation correlation function at time t C = np.sum(dA_n[0 : (N - t)] * dB_n[t:N] + dB_n[0 : (N - t)] * dA_n[t:N]) / ( 2.0 * float(N - t) * sigma2_AB diff --git a/setup.py b/setup.py index c01d9067..2aae1945 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,10 @@ install_requires=["numpy>=1.12", "scipy", "numexpr", - "jaxlib;platform_system!='Windows'", - "jax;platform_system!='Windows'" ], + extras_require={ + "jax": ["jaxlib;platform_system!='Windows'", + "jax;platform_system!='Windows'" + ], + }, )