Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify MBAR and Optimization docs #144

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions docs/user_guide/4.5Optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

## 1. Theory

Automatic differentiation is fundamental to DMFF and aids in neural network optimization. During training, it computes the derivatives from output to input using backpropagation, optimizing parameters through gradient descent. With its efficiency in optimizing high-dimensional parameters, this technique isn't limited to neural networks but suits any framework following the "input parameters → model computation → output" sequence, such as molecular dynamics (MD) simulations. Hence, using automatic differentiation and referencing experimental or ab initio data, we can optimize force field parameters by computing the output's derivative with respect to input parameters.
Automatic differentiation is a crucial component of DMFF and plays a significant role in optimizing neural networks. This technique computes the derivatives of output with respect to input using backpropagation, so parameters optimization can be conducted using gradient descent algorithms. With its efficiency in optimizing high-dimensional parameters, this technique is not limited to training neural networks but is also suitable for any physical model optimization (i.e., molecular force field in the case of DMFF). A typical optimization recipe needs two key ingradients: 1. gradient evaluation, which can be done easily using JAX; and 2. an optimizer that takes gradient as inputs, and update parameters following certain optimization algorithm. To help the users building optimization workflows, DMFF provides an wrapper API for optimizers implemented in [Optax](https://github.com/google-deepmind/optax), which is introduced here.

## 2. Function module

Imports: Importing necessary modules and functions from `jax` and `optax`.

Function `periodic_move`:
- Creates a function to perform a periodic update on parameters. If the update causes the parameters to exceed a given range, they are wrapped around in a periodic manner (like an angle that wraps around after 360 degrees).

Function `genOptimizer`:
- It's a function to generate an optimizer based on user preferences.
- Depending on the arguments, it can produce various optimization schemes, such as SGD, Nesterov, Adam, and others.
- Depending on the arguments, it can produce various optimization schemes, such as SGD, Nesterov, Adam, etc.
- Supports learning rate schedules like exponential decay and warmup exponential decay.
- The optimizer can be further augmented with features like gradient clipping, periodic parameter wrapping, and keeping parameters non-negative.

Expand Down
111 changes: 72 additions & 39 deletions docs/user_guide/4.6MBAR.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,122 @@

## 1. Theory

In molecular dynamics (MD) simulations, the deep computational graph spanning the entire trajectory incurs significant temporal and computational costs. This limitation can be circumvented through trajectory reweighting schemes. In DMFF, the reweighting algorithm is incorporated into the MBAR method, extending the differentiable estimators for average properties and free energy. Although differentiable estimation of dynamic quantities remains a challenge, introducing the reweighted MBAR estimator has optimized the fitting of thermodynamic properties.
In molecular dynamics (MD) simulations, the deep computational graph spanning the entire trajectory incurs significant temporal and computational costs. This limitation can be circumvented through trajectory reweighting schemes. In DMFF, the reweighting algorithm is incorporated into the MBAR method, extending the differentiable estimators for average properties and free energy. Although differentiable estimation of dynamic properties remains a challenge, introducing the reweighted MBAR estimator has largely eased the fitting of thermodynamic properties.

In the MBAR theory, it is assumed that there are K ensembles defined by potential energies
In the MBAR theory, it is assumed that there are K ensembles defined by (effective) potential energies

$$
u_{i}(x)(i=1,2,3,……,K)
$$
```math
\tag{1}
u_{i}(x)\ (i=1,2,3,……,K)
```

For each ensemble, the Boltzmann weight, partition function, and probability function are defined as:

$$
w_i = \exp(-\beta_i u_i(x))
c_i = \int dx \cdot w_i(x)
p_i(x) = \frac{w_i(x)}{c_i}
$$
```math
\tag{2}
\begin{align}
w_i &= \exp(-\beta_i u_i(x)) \\
c_i &= \int dx \cdot w_i(x) \\
p_i(x) &= \frac{w_i(x)}{c_i}
\end{align}
```

For each ensemble $i$, select$N_{i}$ configurations, represented by{${x_{in}}$}$(where n=1,2,3,……,N_{i})$, and the total number of configurations across ensembles is represented by{${x_{n}}$}(n=1,2,3,……,N), where N is:
For each ensemble $i$, select $N_{i}$ configurations, represented by { ${x_{in}}$ } $n=1,2,3,……,N_i$ , and the total number of configurations across ensembles is represented by { ${x_{n}}$ } ( $n=1,2,3,……,N$ ), where N is:

$$
N = \sum_{i=1}^{K} N_i
$$
```math
\tag{3}
N = \sum_{i=1}^{K} N_i
```

Within the context of MBAR, for any ensemble K, the weighted average of the observable is defined as:

$$
\hat{c}_i = \sum_{n=1}^{N} \frac{w_{i}(x_n)}{\sum_{k=1}^{K} N_{k} \hat{c}_k^{-1} w_{k}(x_n)} (1)
$$
```math
\tag{4}
\hat{c}_i = \sum_{n=1}^{N} w_{i}(x_n) \cdot \left(\sum_{k=1}^{K} N_{k} \hat{c}_k^{-1} w_{k}(x_n)\right)^{-1}
```

To compute the average of a physical quantity A in ensemble i, one can utilize the above values to define a virtual ensemble j and provide its corresponding Boltzmann weight and partition function:
To compute the average of a physical quantity $A$ in ensemble $i$, one can utilize the above values to define a virtual ensemble $j$ , with its corresponding Boltzmann weight and partition function:

$$w_j = w_i(x)A(x)$$
$$c_i = \int dx \cdot w_j(x)$$
```math
\tag{5}
\begin{align}
w_j &= w_i(x)A(x) \\
c_i &= \int dx \cdot w_j(x)
\end{align}
```

Thus, the ensemble average of A is:
$$\langle A \rangle_i = \frac{\hat{c}_j}{\hat{c}_i} = \frac{\int dx \cdot w_i(x)A(x)}{\int dx \cdot w_i(x)}$$

Thus, the MBAR theory provides a method for estimating the average of physical properties using multiple samples.
```math
\tag{6}
\langle A \rangle_i = \frac{\hat{c}_j}{\hat{c}_i} = \frac{\int dx \cdot w_i(x)A(x)}{\int dx \cdot w_i(x)}
```

Thus, the MBAR theory provides a method to estimate the ensemble averages using multiple samples from different ensembles.

In the MBAR framework, $\hat{c}_i$ in Eqn (4) needs to be solved iteratively; however, the differentiable reweighting algorithm can simplify this estimation process. During the gradient descent parameter optimization, the parameters undergo only small changes in each training cycle. This allows for the usage of samples from the previous cycles to evaluate the target ensemble that is being optimized. So resampling is not necessary until the target ensemble deviates significantly from the sampling ensemble. This reduces the time and computational cost of the optimization considerably.

In the MBAR framework, $\hat{c}_i$in Equation (1) requires iterative solution; however, the reweighting algorithm can simplify this estimation process. During gradient descent training for parameter optimization, the parameters undergo only slight perturbations in each training cycle. This allows for the continued use of samples from the previous cycle, such that resampling is not necessary until the optimized ensemble deviates significantly from the sampling ensemble, considerably reducing optimization time and computational cost. In the reweighted MBAR estimator, we define two types of ensembles: the sampling ensemble, from which all samples are extracted (assuming there are m samples, labeled as m=1, 2, 3, …, M), and the target ensemble, which needs optimization (corresponding to the above i, j, labeled as p, q). The sampling ensemble is updated only when necessary and does not need to be differentiable. Its data can be generated by external samplers like OpenMM. Hence, $\hat{c}_i$ can be transformed into:
In the reweighted MBAR estimator, we define two types of ensembles: the sampling ensemble, from which all samples are extracted (labeled as $m=1, 2, 3, …, M$ ), and the target ensemble, which needs optimization (labeled as $p, q$, corresponding to the indices $i, j$ in Eqn (6)). The sampling ensemble is updated only when necessary and does not need to be differentiable. Its data can be generated by external samplers like OpenMM. Hence, $\hat{c}_i$ can be transformed into:

$$\hat{c}_p = \sum_{n=1}^{N} w_{p}(x_n) \left( \sum_{m=1}^{M} N_{m} \hat{c}_m^{-1} w_{m}(x_n) \right)^{-1}(2)$$
```math
\tag{7}
\hat{c}_p = \sum_{n=1}^{N} w_{p}(x_n) \left( \sum_{m=1}^{M} N_{m} \hat{c}_m^{-1} w_{m}(x_n) \right)^{-1}
```

Every time resampling is needed, the method of iteratively solving Equation(1) is utilized for updating $\hat{c}_m$ and storage until the next sampling. Subsequently, during the parameter optimization process, Equation (2) is employed for computating $\hat{c}_p$ and serves as a differentiable estimator.
When resample happens, Eqn. (4) is solved iteratively using standard MBAR to update $\hat{c}_m$, which is stored and used to evaluate $\hat{c}_p$ until the next resampling. Subsequently, during the parameter optimization process, Eqn (7) is employed to compute $\hat{c}_p$, serving as a differentiable estimator.

Below, we illustrate the workflow of how to use MBAR Estimator in DMFF through a case study.

If all sampling ensembles are defined as a single ensemble $w_{0}(x)$, and the target ensemble is defined as $w_{p}(x)$, and for physical quantity A, we have:

$$w_q(x) = w_p(x) A(x)$$
```math
\tag{8}
w_q(x) = w_p(x) A(x)
```

and define:

$$\Delta u_{p_0} = u_p(x) - u_0(x)$$
```math
\tag{9}
\Delta u_{p_0} = u_p(x) - u_0(x)
```

then:

$$\langle A \rangle_p = \frac{\hat{c}_q}{\hat{c}_p}= \frac{\sum_{n=1}^{N} A(x_n) \exp(-\beta \Delta u_{p_0}(x_n))}{\sum_{n=1}^{N} \exp(-\beta \Delta u_{p_0}(x_n))}$$
```math
\tag{10}
\langle A \rangle_p = \frac{\hat{c}_q}{\hat{c}_p} = \left(\sum_{n=1}^{N} A(x_n) \exp(-\beta \Delta u_{p_0}(x_n))\right) \cdot \left(\sum_{n=1}^{N} \exp(-\beta \Delta u_{p_0}(x_n))\right)^{-1}
```

Refers to equations above, this equation indicates that the trajectory reweighting algorithm is a special case of the reweighted MBAR estimator.

In DMFF, when calculating the average of the physical quantity A, the formula is expressed as:

$$
```math
\tag{11}
\langle A \rangle_p = \sum_{n=1}^{N} W_n A(x_n)
$$
```

where

$$
```math
\tag{12}
\Delta U_{mp} = U_m(x_n) - U_p(x_n)
$$
```

$$
W_n = \frac{\left[\sum_{m=1}^{M} N_m e^{\hat{f}_m -\beta \Delta U_{mp}(x_n)}\right]^{-1}}{\sum_{n=1}^{N} \left[ \sum_{m=1}^{M} N_m e^{\hat{f}_m -\beta \Delta U_{mp}(x_n)} \right]^{-1}}
$$
```math
\tag{13}
W_n = \left[\sum_{m=1}^{M} N_m e^{\hat{f}_m -\beta \Delta U_{mp}(x_n)}\right]^{-1} \cdot \left(\sum_{n=1}^{N} \left[ \sum_{m=1}^{M} N_m e^{\hat{f}_m -\beta \Delta U_{mp}(x_n)} \right]^{-1}\right)^{-1}
```

$\hat{f}_m$ is the partition function of the sampling state. W is the MBAR weight for each sample. Finally, the effective sample size is given, based on which one can judge the deviation of the sampling ensemble from the target ensemble:

$$
n_{\text{eff}} = \frac{\left(\sum_{n=1}^{N} W_n\right)^2}{\sum_{n=1}^{N} W_n^2}
$$
```math
\tag{14}
n_{\text{eff}} = \left(\sum_{n=1}^{N} W_n\right)^2\cdot\left(\sum_{n=1}^{N} W_n^2\right)^{-1}
```

When $n_{eff}$ is too small, it indicates that the current sampling ensemble deviates too much from the target ensemble and needs to be resampled.
When $n_{eff}$ is too small, it indicates that the current sampling ensemble deviates too much from the target ensemble and resample is needed.

Here is a graphical representation of the workflow mentioned above:

Expand Down
Loading