Skip to content

Commit

Permalink
Release DMFF v0.1.0 (#38)
Browse files Browse the repository at this point in the history
* add nblist wrapper and its docs

* fix(generator): existing jax bond/angle generators

* fix: LJ bug; all classical code can be jitted; bug at test_gaff2_force

* fix test_gaff2_force bug

* fix: non-differentiable error, move args check in the api.py

* fix: remove redundancy `box=jnp.array(box)`; confirm isinstance_jnp is jit-compatiable

* refine(pme): code prettify

* refine(utils): code prettify

* add(nblist): add jit and auto update nblist

* refine(classical): rm unused import

* refactor(classical): add dispCorr

* add(classical): add free energy support

* refactor renderXML with clean commits (#25)

* feat:rewrite renderXML with very clean commit history

* fix bug in Torsion renderXML

* add test_utils as a placehold

* fix: fix typo in api.py

* docs: add renderXML related api usages

* refine(fep): rm debugging codes

* add(cicd): CI/CD workflows

* add(ut): ut for regularize_pairs and buffer_scales

* fix(CI/CD): activate conda env

* fix(CI/CD): wrong number in test_nblist

* add(requirement): add dependencies list

* add(CI/CD): unittest workflows

* add(requirements): dependencies list

* add test_utils.py

* fix: modified unit test results

* add `r` to avoid latex being recognized as an escape character

* fix(ut): wrong number in test_nblist

* Fix unit test related problems (#27)

* add test_utils.py

* fix: modified unit test results

* add `r` to avoid latex being recognized as an escape character

* fix(ut): wrong number in test_nblist

Co-authored-by: Yingze Wang <[email protected]>

* refine(ut): code prettify in test_nblist

* Add CI/CD Workflows (#26)

* add(CI/CD): unittest workflows

* add(requirements): dependencies list

* refine(ut): code prettify in test_nblist

* update(api): switch default dispcorr to False

* Chore: clean admp module up

* chore: clean up classical and api.py

* fix(ut): withdraw last commit about fix wrong number in test_nblist

* update(api): fix dispcorr countmat bugs

* update(pme): code prettify & gmx ewald coeff determine

* add(constants): module to control constants

* add(unittest): fep ut

* update(test_nblist): unused imports

* update(gitignore): acpype cache

* refactor(test_classical): split to several files

* refactor(test_classical): rename test_classical

* fix(inter): PME in classical forcefield

* update(fep): use dmff.common.constants

* update(pme): default args in setup_ewald_parameters

* Fix the nb list description in doc

* update mkdocs.yml and a simple test of api generator

* fix: add step_pol arg in ADMPPmeForce __init__ to fix it can not be jitted bug

* add step_pol arg in ADMPPmeForce __init__ to fix it can not be jitted bug

* Improve docstrings in sgnn

* add read_input_info in dmff.admp.parse.py to deal with info in pdb and xml (input:pdb,xml output: multipoles/polarizabilities/tholes...)
use forcefield.xml instead of other xmls, add read_admp_xml in parser in order to read the new form of xml file
change line 795 in pme.py and add distribute_matrix in pairwise.py in order to wrap a jit outside my admp_calculator

* Update some examples and ref_outs

* feat: auto gen docs refs

* fix: typo in requirements.txt

* add change leading terms in api.py

* add fluctuated leading term computation in dmff.api

* add fluctuated leading term compute in dmff.api

* use dmff.api to deal with input ; wrap admp_calculator to change fixed params into fluctuated one amf jit

* wrap compute leading term in run.py

* add a markdown and a generate_calculator outside admp_calculator function

* feat(classical): Hamiltonian can create total energy function now. (#33)

* Docs refine (#34)

* add(CI/CD): unittest workflows

* add(requirements): dependencies list

* update(doc): fix bad render of \hat

* update(doc): add examples and doc struct refine

* Improve the fluctuating charge jupyter notebook demo

* Delete parser_bk.py

* Fix the fluctuating atom charge demo

* feat: update docs configs

* fix: fix test_nblist bug

* update: license in docs

* fix: import missing in api.py (may caused by formatter)

* FIX: wrong links in docs and add requirements (#37)

* update(doc): fix bad render of \hat

* update(doc): add examples and doc struct refine

* hotfix(docs): license bugs

Co-authored-by: Kuang Yu <[email protected]>
Co-authored-by: Roy Kid <[email protected]>
Co-authored-by: Jichen Li <[email protected]>
Co-authored-by: crone <[email protected]>
Co-authored-by: WangXinyan940 <[email protected]>
  • Loading branch information
6 people authored Jun 14, 2022
1 parent 00b3a41 commit 846e3b8
Show file tree
Hide file tree
Showing 76 changed files with 18,985 additions and 10,443 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/ut.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: DMFF's python tests.

on:
push:
pull_request:

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: |
source $CONDA/bin/activate
$CONDA/bin/conda update -n base -c defaults conda
conda install pip
conda update pip
conda install numpy openmm pytest -c conda-forge
pip install jax jax_md
- name: Install DMFF
run: |
source $CONDA/bin/activate
pip install .
- name: Run Tests
run: |
source $CONDA/bin/activate
pytest -vs tests/
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -775,4 +775,7 @@ FodyWeavers.xsd

### VisualStudio Patch ###
# Additional files built by Visual Studio
.vscode/**
.vscode/**

# acpype cache
*.acpype/
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ The behavior of organic molecular systems (e.g., protein folding, polymer struct

+ [1. Introduction](docs/user_guide/introduction.md)
+ [2. Installation](docs/user_guide/installation.md)
+ [3. Compute energy and forces](docs/user_guide/compute.md)
+ [4. Compute gradients with auto differentiable framework](docs/user_guide/auto_diff.md)
+ [5. Theories](docs/user_guide/theory.md)
+ [6. Introduction to force field xml files](docs/user_guide/xml_spec.md)
+ [3. Basic usage](docs/user_guide/usage.md)
+ [4. XML format force field](docs/user_guide/xml_spec.md)
+ [5. Theory](docs/user_guide/theory.md)

## Developer Guide
+ [1. Introduction](docs/dev_guide/introduction.md)
+ [2. Architecture](docs/dev_guide/arch.md)
+ [3. Convention](docs/dev_guide/convention.md)
+ [2. Software architecture](docs/dev_guide/arch.md)
+ [3. Coding conventions](docs/dev_guide/convention.md)
+ [4. Document writing](docs/dev_guide/write_docs.md)

## Modules
+ [1. ADMP](docs/modules/admp.md)
Expand Down
4 changes: 3 additions & 1 deletion dmff/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
import dmff.settings
from .settings import *
from .common.nblist import NeighborList
from .api import Hamiltonian
15 changes: 9 additions & 6 deletions dmff/admp/disp_pme.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from functools import partial

import jax.numpy as jnp
from jax import vmap, value_and_grad
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import pbc_shift
from dmff.admp.pairwise import (distribute_dispcoeff, distribute_scalar,
distribute_v3)
from dmff.admp.pme import setup_ewald_parameters
from dmff.admp.recip import generate_pme_recip, Ck_6, Ck_8, Ck_10
from dmff.admp.pairwise import distribute_scalar, distribute_v3, distribute_dispcoeff
from functools import partial
from dmff.admp.recip import Ck_6, Ck_8, Ck_10, generate_pme_recip
from dmff.admp.spatial import pbc_shift
from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
from jax import value_and_grad, vmap


class ADMPDispPmeForce:
'''
Expand Down
76 changes: 16 additions & 60 deletions dmff/admp/mbpol_intra.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@

import sys
import numpy as np
import jax.numpy as jnp
from jax import grad, value_and_grad
from dmff.settings import DO_JIT
from dmff.utils import jit_condition
import numpy as np
from dmff.admp.spatial import v_pbc_shift
from dmff.admp.pme import ADMPPmeForce
from dmff.admp.parser import *
from dmff.utils import jit_condition
from jax import vmap
import time
#from admp.multipole import convert_cart2harm
#from jax_md import partition, space

#const
f5z = 0.999677885
fbasis = 0.15860145369897
fcore = -1.6351695982132
frest = 1.0
reoh = 0.958649;
thetae = 104.3475;
b1 = 2.0;
roh = 0.9519607159623009;
alphaoh = 2.587949757553683;
deohA = 42290.92019288289;
phh1A = 16.94879431193463;
phh2 = 12.66426998162947;
reoh = 0.958649
thetae = 104.3475
b1 = 2.0
roh = 0.9519607159623009
alphaoh = 2.587949757553683
deohA = 42290.92019288289
phh1A = 16.94879431193463
phh2 = 12.66426998162947

c5zA = jnp.array([4.2278462684916e+04, 4.5859382909906e-02, 9.4804986183058e+03,
7.5485566680955e+02, 1.9865052511496e+03, 4.3768071560862e+02,
Expand Down Expand Up @@ -467,12 +460,14 @@ def onebodyenergy(positions, box):

@vmap
@jit_condition(static_argnums={})
def onebody_kernel(x1, x2, x3, Va, Vb, efac):
def onebody_kernel(x1, x2, x3, Va, Vb, efac):
a = jnp.arange(-1,15)
a = a.at[0].set(0)
const = jnp.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
CONST = jnp.array([const,const,const])
list1 = jnp.array([x1**i for i in range(-1, 15)])
list2 = jnp.array([x2**i for i in range(-1, 15)])
list3 = jnp.array([x3**i for i in range(-1, 15)])
list1 = jnp.array([x1**i for i in a])
list2 = jnp.array([x2**i for i in a])
list3 = jnp.array([x3**i for i in a])
fmat = jnp.array([list1, list2, list3])
fmat *= CONST
F1 = jnp.sum(fmat[0].T * matrix1, axis=1) # fmat[0][inI] 1*245
Expand All @@ -488,42 +483,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac):
e1 *= cm1_kcalmol
e1 *= cal2joule # conver cal 2 j
return e1


def validation(pdb):
xml = 'mpidwater.xml'
pdbinfo = read_pdb(pdb)
serials = pdbinfo['serials']
names = pdbinfo['names']
resNames = pdbinfo['resNames']
resSeqs = pdbinfo['resSeqs']
positions = pdbinfo['positions']
box = pdbinfo['box'] # a, b, c, α, β, γ
charges = pdbinfo['charges']
positions = jnp.asarray(positions)
lx, ly, lz, _, _, _ = box
box = jnp.eye(3)*jnp.array([lx, ly, lz])

mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])

rc = 4 # in Angstrom
ethresh = 1e-4

n_atoms = len(serials)

# compute intra


grad_E1 = value_and_grad(onebodyenergy,argnums=(0))
ene, force = grad_E1(positions, box)
print(ene,force)
return


# below is the validation code
if __name__ == '__main__':
validation(sys.argv[1])


10 changes: 5 additions & 5 deletions dmff/admp/multipole.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys
from functools import partial

import jax.numpy as jnp
from jax import vmap
from dmff.utils import jit_condition
from functools import partial
from jax import vmap

# This module deals with the transformations and rotations of multipoles

Expand Down Expand Up @@ -48,7 +48,7 @@ def convert_cart2harm(Theta, lmax):
n * (l+1)^2, stores the spherical multipoles
'''
if lmax > 2:
sys.exit('l > 2 (beyond quadrupole) not supported')
raise ValueError('l > 2 (beyond quadrupole) not supported')

Q_mono = Theta[0:1]

Expand Down Expand Up @@ -90,7 +90,7 @@ def convert_harm2cart(Q, lmax):
'''

if lmax > 2:
sys.exit('l > 2 (beyond quadrupole) not supported')
raise ValueError('l > 2 (beyond quadrupole) not supported')

T_mono = Q[0:1]

Expand Down
116 changes: 8 additions & 108 deletions dmff/admp/pairwise.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
from jax import vmap
from functools import partial

import jax.numpy as jnp
from dmff.utils import jit_condition, regularize_pairs, pair_buffer_scales
from dmff.admp.spatial import v_pbc_shift
from functools import partial
from dmff.utils import jit_condition, pair_buffer_scales, regularize_pairs
from jax import vmap

DIELECTRIC = 1389.35455846

Expand Down Expand Up @@ -40,6 +40,9 @@ def distribute_multipoles(multipoles, index):
def distribute_dispcoeff(c_list, index):
return c_list[index]

@jit_condition(static_argnums=())
def distribute_matrix(multipoles,index1,index2):
return multipoles[index1,index2]

def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args):
'''
Expand Down Expand Up @@ -130,7 +133,7 @@ def TT_damping_qq_kernel(dr, m, bi, bj, qi, qj):
@vmap
@jit_condition(static_argnums=())
def slater_disp_damping_kernel(dr, m, bi, bj, c6i, c6j, c8i, c8j, c10i, c10j):
'''
r'''
Slater-ISA type damping for dispersion:
f(x) = -e^{-x} * \sum_{k} x^k/k!
x = Br - \frac{2*(Br)^2 + 3Br}{(Br)^2 + 3*Br + 3}
Expand Down Expand Up @@ -167,106 +170,3 @@ def slater_sr_kernel(dr, m, ai, aj, bi, bj):
P = 1/3 * br2 + br + 1
return a * P * jnp.exp(-br) * m


def validation(pdb):
xml = 'mpidwater.xml'
pdbinfo = read_pdb(pdb)
serials = pdbinfo['serials']
names = pdbinfo['names']
resNames = pdbinfo['resNames']
resSeqs = pdbinfo['resSeqs']
positions = pdbinfo['positions']
box = pdbinfo['box'] # a, b, c, α, β, γ
charges = pdbinfo['charges']
positions = jnp.asarray(positions)
lx, ly, lz, _, _, _ = box
box = jnp.eye(3)*jnp.array([lx, ly, lz])

mScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
pScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])
dScales = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0])

rc = 4 # in Angstrom
ethresh = 1e-4

n_atoms = len(serials)

atomTemplate, residueTemplate = read_xml(xml)
atomDicts, residueDicts = init_residues(serials, names, resNames, resSeqs, positions, charges, atomTemplate, residueTemplate)

covalent_map = assemble_covalent(residueDicts, n_atoms)
displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
nbr = neighbor_list_fn.allocate(positions)
pairs = nbr.idx.T

pmax = 10
kappa, K1, K2, K3 = setup_ewald_parameters(rc, ethresh, box)
kappa = 0.657065221219616

# construct the C list
c_list = np.zeros((3, n_atoms))
a_list = np.zeros(n_atoms)
q_list = np.zeros(n_atoms)
b_list = np.zeros(n_atoms)
nmol=int(n_atoms/3)
for i in range(nmol):
a = i*3
b = i*3+1
c = i*3+2
# dispersion coeff
c_list[0][a]=37.199677405
c_list[0][b]=7.6111103
c_list[0][c]=7.6111103
c_list[1][a]=85.26810658
c_list[1][b]=11.90220148
c_list[1][c]=11.90220148
c_list[2][a]=134.44874488
c_list[2][b]=15.05074749
c_list[2][c]=15.05074749
# q
q_list[a] = -0.741706
q_list[b] = 0.370853
q_list[c] = 0.370853
# b, Bohr^-1
b_list[a] = 2.00095977
b_list[b] = 1.999519942
b_list[c] = 1.999519942
# a, Hartree
a_list[a] = 458.3777
a_list[b] = 0.0317
a_list[c] = 0.0317


c_list = jnp.array(c_list)

# @partial(vmap, in_axes=(0, 0, 0, 0), out_axes=(0))
# @jit_condition(static_argnums=())
# def disp6_pme_real_kernel(dr, m, ci, cj):
# # unpack static arguments
# kappa = static_args['kappa']
# # calculate distance
# dr2 = dr ** 2
# dr6 = dr2 ** 3
# # do calculation
# x2 = kappa**2 * dr2
# exp_x2 = jnp.exp(-x2)
# x4 = x2 * x2
# g = (1 + x2 + 0.5*x4) * exp_x2
# return (m + g - 1) * ci * cj / dr6

# static_args = {'kappa': kappa}
# disp6_pme_real = generate_pairwise_interaction(disp6_pme_real_kernel, covalent_map, static_args)
# print(disp6_pme_real(positions, box, pairs, mScales, c_list[0, :]))

TT_damping_qq_c6 = generate_pairwise_interaction(TT_damping_qq_c6_kernel, covalent_map, static_args={})

TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0])
print('ok')
print(TT_damping_qq_c6(positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0]))
return


# below is the validation code
if __name__ == '__main__':
validation(sys.argv[1])
7 changes: 5 additions & 2 deletions dmff/admp/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import warnings
from collections import defaultdict
import jax.numpy as jnp
from dmff.admp.multipole import convert_cart2harm

def read_atom_line(line_full):
"""
Expand Down Expand Up @@ -326,7 +328,8 @@ def read_xml(fileobj):
set_axis_type(atomTemplates)

return atomTemplates, residueTemplates



class Atom:

def __init__(self, serial, name, resName, resSeq, position, charge, ) -> None:
Expand Down Expand Up @@ -474,4 +477,4 @@ def assemble_covalent(residueDicts, natoms):
covalent_map[c][pp] = dr

return covalent_map

Loading

0 comments on commit 846e3b8

Please sign in to comment.