Skip to content

Commit

Permalink
Merge pull request #169 from dingye18/octopole
Browse files Browse the repository at this point in the history
Implementation of the Octupole Support for ADMP
  • Loading branch information
KuangYu authored Apr 7, 2024
2 parents e299ced + 20cdbfd commit 8c4c767
Show file tree
Hide file tree
Showing 14 changed files with 1,816 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
conda create -n dmff -y python=${{ matrix.python-version }} numpy openmm==7.7.0 pytest rdkit openbabel mdtraj ambertools -c conda-forge
conda activate dmff
pip install --upgrade pip
pip install jax jaxlib jaxopt networkx parmed pymbar==4.0.1 optax tqdm
pip install jax==0.4.24 jaxlib==0.4.24 jaxopt networkx parmed pymbar==4.0.1 optax tqdm
- name: Install DMFF
run: |
source $CONDA/bin/activate dmff && pip install .
Expand Down
7 changes: 6 additions & 1 deletion backend/openmm_dmff_plugin/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ execute_process(
OUTPUT_VARIABLE GIT_HASH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
execute_process(
COMMAND git describe --tags --abbrev=0
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
OUTPUT_VARIABLE GIT_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE
)

# Compile the Python module.
add_custom_target(PythonInstall DEPENDS "${WRAP_FILE}")
Expand All @@ -36,4 +42,3 @@ add_custom_command(TARGET PythonInstall
COMMAND "${PYTHON_EXECUTABLE}" setup.py install
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
)

2 changes: 1 addition & 1 deletion backend/openmm_dmff_plugin/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


setup(name='OpenMMDMFFPlugin',
version="@GIT_HASH@",
version="@GIT_VERSION@".lstrip('v').replace('-', ''),
ext_modules=[extension],
packages=['OpenMMDMFFPlugin', "OpenMMDMFFPlugin.tests"],
package_data={"OpenMMDMFFPlugin":['data/lj_fluid/*.pb', 'data/lj_fluid/variables/variables.index', 'data/lj_fluid/variables/variables.data-00000-of-00001', 'data/lj_fluid_gpu/*.pb', 'data/lj_fluid_gpu/variables/variables.index', 'data/lj_fluid_gpu/variables/variables.data-00000-of-00001', 'data/*.pdb']},
Expand Down
141 changes: 132 additions & 9 deletions dmff/admp/multipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@
# The important conversion matrices used in multipoles
rt3 = 1.73205080757
inv_rt3 = 1.0/rt3
rt2 = 1.41421356237
inv_rt2 = 1.0/rt2
rt5 = 2.2360679775
inv_rt5 = 1.0/rt5
rt6 = 2.44948974278
inv_rt6 = 1.0/rt6
rt10 = 3.16227766017
inv_rt10 = 1.0/rt10
rt8 = 2.82842712475
rt12 = 3.46410161514
rt15 = 3.87298334621
rt24 = 4.89897948557
inv_rt24 = 1.0/rt24


# the dipole conversion matrices, cart2harm and harm2cart
C1_h2c = jnp.array([[0, 1, 0],
[0, 0, 1],
Expand All @@ -29,6 +44,29 @@
[ 0, 0, 0, 0, rt3/2],
[ 0, rt3/2, 0, 0, 0],
[ 0, 0, rt3/2, 0, 0]])
# the octupole conversion matrices
C3_c2h = jnp.array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[ 0, 0, 0, 0, 0, 0, 0, rt3/rt2, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, rt3/rt2, 0],
[ 0, 0, 0, 0, rt3/rt5, 0, -rt3/rt5, 0, 0, 0],
[ 0, 0, 0, 0, 0, 2*rt3/rt5, 0, 0, 0, 0],
[ inv_rt10, 0, -3*inv_rt10, 0, 0, 0, 0, 0, 0, 0],
[ 0, 3*inv_rt10, 0, -inv_rt10, 0, 0, 0, 0, 0, 0]])


C3_h2c = jnp.array([[ 0, -rt3/rt8, 0, 0, 0, rt5/rt8, 0],
[ 0, 0, -inv_rt24, 0, 0, 0, rt5/rt8],
[ 0, -inv_rt24, 0, 0, 0, -rt5/rt8, 0],
[ 0, 0, -rt3/rt8, 0, 0, 0, -rt5/rt8],
[ -0.5, 0, 0, rt5/rt12, 0, 0, 0],
[ 0, 0, 0, 0, rt5/rt12, 0, 0],
[ -0.5, 0, 0, -rt5/rt12, 0, 0, 0],
[ 0, rt2/rt3, 0, 0, 0, 0, 0],
[ 0, 0, rt2/rt3, 0, 0, 0, 0],
[ 1, 0, 0, 0, 0, 0, 0]])





@partial(vmap, in_axes=(0, None), out_axes=0)
Expand All @@ -48,8 +86,8 @@ def convert_cart2harm(Theta, lmax):
Q:
n * (l+1)^2, stores the spherical multipoles
'''
if lmax > 2:
raise ValueError('l > 2 (beyond quadrupole) not supported')
if lmax > 3:
raise ValueError('l > 3 (beyond octupole) not supported')

Q_mono = Theta[0:1]

Expand All @@ -61,13 +99,21 @@ def convert_cart2harm(Theta, lmax):
if lmax >= 2:
quad_cart = Theta[4:10].T
Q_quad = C2_c2h.dot(quad_cart).T

# octupole
if lmax >= 3:
octu_cart = Theta[10:20].T
Q_octu = C3_c2h.dot(octu_cart).T

if lmax == 0:
Q = Q_mono
elif lmax == 1:
Q = jnp.hstack([Q_mono, Q_dip])
else:
elif lmax == 2:
Q = jnp.hstack([Q_mono, Q_dip, Q_quad])
elif lmax == 3:
Q = jnp.hstack([Q_mono, Q_dip, Q_quad, Q_octu])
else:
raise ValueError('l > 3 (beyond octupole) not supported')

return Q

Expand All @@ -90,8 +136,8 @@ def convert_harm2cart(Q, lmax):
n * n_cart, stores the cartesian multipoles
'''

if lmax > 2:
raise ValueError('l > 2 (beyond quadrupole) not supported')
if lmax > 3:
raise ValueError('l > 3 (beyond octupole) not supported')

T_mono = Q[0:1]

Expand All @@ -101,13 +147,20 @@ def convert_harm2cart(Q, lmax):
# quadrupole
if lmax >= 2:
T_quad = C2_h2c.dot(Q[4:9].T).T
# octupole
if lmax >= 3:
T_octu = C3_h2c.dot(Q[9:16].T).T

if lmax == 0:
T = T_mono
elif lmax == 1:
T = jnp.hstack([T_mono, T_dip])
else:
elif lmax == 2:
T = jnp.hstack([T_mono, T_dip, T_quad])
elif lmax == 3:
T = jnp.hstack([T_mono, T_dip, T_quad, T_octu])
else:
raise ValueError('l > 3 (beyond octupole) not supported')

return T

Expand Down Expand Up @@ -143,8 +196,8 @@ def rot_global2local(Q_gh, localframes, lmax=2):
Q_lh:
n * (l+1)^2, stores the local harmonic multipole moments
'''
if lmax > 2:
raise NotImplementedError('l > 2 (beyond quadrupole) not supported')
if lmax > 3:
raise NotImplementedError('l > 3 (beyond octupole) not supported')

# monopole
Q_lh_0 = Q_gh[0:1]
Expand Down Expand Up @@ -204,12 +257,82 @@ def rot_global2local(Q_gh, localframes, lmax=2):
]
).swapaxes(0,1)
Q_lh_2 = jnp.einsum('jk,k->j', C2_gl, quadrupoles)

if lmax >= 3:
octupoles = Q_gh[9:16]
C3_gl_00 = ( -8 * xx * yy + 8 * yx * xy + 5 * zz **3 + 5 * zz) / 2
C3_gl_01 = (rt6*zx*(5.0*zz**2-1.0))/4.0
C3_gl_02 = (rt6*zy*(5.0*zz**2-1.0))/4.0
C3_gl_03 = (rt15*zz*(-2.0*zy**2-zz**2+1.0))/2.0
C3_gl_04 = rt15*zx*zy*zz
C3_gl_05 = (rt10*zx*(-4.0*zy**2-zz**2+1.0))/4.0
C3_gl_06 = (rt10*zy*(-4.0*zy**2-3.0*zz**2+3.0))/4.0
C3_gl_10 = (rt3*xz*(5.0*zz**2-1.0))/(2.0*rt2)
C3_gl_11 = (-10.0*xx*yy**2+15.0*xx*zz**2-xx+10.0*yx*xy*yy)/4.0
C3_gl_12 = (10.0*xy*yz**2+15.0*xy*zz**2-11.0*xy-10.0*yy*xz*yz)/4.0
C3_gl_13 = (rt10*(4.0*xy*yy*yz-4.0*yy**2*xz-6.0*zy**2*xz-3.0*xz*zz**2+5.0*xz))/4.0
C3_gl_14 = rt10*(-xx*yy*yz-yx*xy*yz+2.0*yx*yy*xz+3.0*zx*zy*xz)/2.0
C3_gl_15 = (rt15*(-2.0*xx*yy**2-4.0*xx*zy**2-xx*zz**2+3.0*xx+2.0*yx*xy*yy))/4.0
C3_gl_16 = (rt15*(-4.0*xy*zy**2-2.0*xy*yz**2-3.0*xy*zz**2+3.0*xy+2.0*yy*xz*yz))/4.0
C3_gl_20 = (rt3*yz*(5.0*zz**2-1.0))/(2.0*rt2)
C3_gl_21 = (10.0*yx*zy**2+15.0*yx*zz**2-11.0*yx-10.0*zx*yy*zy)/4.0
C3_gl_22 = (5.0*yy*zz**2-yy+10.0*zy*yz*zz)/4.0
C3_gl_23 = (rt10*(-4.0*yy*zy*zz-2.0*zy**2*yz-3.0*yz*zz**2+yz))/4.0
C3_gl_24 = (rt10*(yx*zy*zz+zx*yy*zz+zx*zy*yz))/2.0
C3_gl_25 = (rt15*(-2.0*yx*zy**2-yx*zz**2+yx-2.0*zx*yy*zy))/4.0
C3_gl_26 = (rt15*(-4.0*yy*zy**2-yy*zz**2+yy-2.0*zy*yz*zz))/4.0
C3_gl_30 = rt15*zz*(-2.0*yz**2-zz**2+1.0)/2.0
C3_gl_31 = (rt10*(4.0*yx*yy*zy-4.0*zx*yy**2-6.0*zx*yz**2-3.0*zx*zz**2+5.0*zx))/4.0
C3_gl_32 = (rt10*(-4.0*yy*yz*zz-2.0*zy*yz**2-3.0*zy*zz**2+zy))/4.0
C3_gl_33 = (-4.0*xx*yy-4.0*yx*xy+12.0*yy**2*zz+6.0*zy**2*zz+6.0*yz**2*zz+3.0*zz**3-9.0*zz)/2.0
C3_gl_34 = -6.0*yx*yy*zz-3.0*zx*zy*zz-4.0*xy*yy-2.0*xz*yz
C3_gl_35 = (rt6*(4.0*yx*yy*zy+4.0*zx*yy**2+4.0*zx*zy**2+2.0*zx*yz**2+zx*zz**2-3.0*zx))/4.0
C3_gl_36 = (rt6*(8.0*yy**2*zy+4.0*yy*yz*zz+4.0*zy**3+2.0*zy*yz**2+3.0*zy*zz**2-5.0*zy))/4.0
C3_gl_40 = rt15*xz*yz*zz
C3_gl_41 = (rt10*(-xx*yy*zy-yx*xy*zy+2.0*zx*xy*yy+3.0*zx*xz*yz))/2.0
C3_gl_42 = (rt10*(xy*yz*zz+yy*xz*zz+zy*xz*yz))/2.0
C3_gl_43 = -4.0*yx*yy-2.0*zx*zy-6.0*xy*yy*zz-3.0*xz*yz*zz
C3_gl_44 = 3.0*xx*yy*zz+3.0*yx*xy*zz-4.0*yy**2-2.0*zy**2-2.0*yz**2-zz**2+3.0
C3_gl_45 = (rt6*(-xx*yy*zy-yx*xy*zy-2.0*zx*xy*yy-zx*xz*yz))/2.0
C3_gl_46 = (rt6*(-4.0*xy*yy*zy-xy*yz*zz-yy*xz*zz-zy*xz*yz))/2.0
C3_gl_50 = (rt5*xz*(-4.0*yz**2-zz**2+1.0))/(2.0*rt2)
C3_gl_51 = (rt15*(-2.0*xx*yy**2-4.0*xx*yz**2-xx*zz**2+3.0*xx+2.0*yx*xy*yy))/4.0
C3_gl_52 = (rt15*(-2.0*xy*yz**2-xy*zz**2+xy-2.0*yy*xz*yz))/4.0
C3_gl_53 = (rt6*(4.0*xy*yy*yz+4.0*yy**2*xz+2.0*zy**2*xz+4.0*xz*yz**2+xz*zz**2-3.0*xz))/4.0
C3_gl_54 = (rt6*(-xx*yy*yz-yx*xy*yz-2.0*yx*yy*xz-zx*zy*xz))/2.0
C3_gl_55 = (10.0*xx*yy**2+4.0*xx*zy**2+4.0*xx*yz**2+xx*zz**2-7.0*xx+6.0*yx*xy*yy)/4.0
C3_gl_56 = (16.0*xy*yy**2+4.0*xy*zy**2+6.0*xy*yz**2+3.0*xy*zz**2-7.0*xy+6.0*yy*xz*yz)/4.0
C3_gl_60 = (rt5*yz*(-4.0*yz**2-3.0*zz**2+3.0))/(2.0*rt2)
C3_gl_61 = (rt15*(-2.0*yx*zy**2-4.0*yx*yz**2-3.0*yx*zz**2+3.0*yx+2.0*zx*yy*zy))/4.0
C3_gl_62 = (rt15*(-4.0*yy*yz**2-yy*zz**2+yy-2.0*zy*yz*zz))/4.0
C3_gl_63 = (rt6*(8.0*yy**2*yz+4.0*yy*zy*zz+2.0*zy**2*yz+4.0*yz**3+3.0*yz*zz**2-5.0*yz))/4.0
C3_gl_64 = (rt6*(-4.0*yx*yy*yz-yx*zy*zz-zx*yy*zz-zx*zy*yz))/2.0
C3_gl_65 = (16.0*yx*yy**2+6.0*yx*zy**2+4.0*yx*yz**2+3.0*yx*zz**2-7.0*yx+6.0*zx*yy*zy)/4.0
C3_gl_66 = (16.0*yy**3+12.0*yy*zy**2+12.0*yy*yz**2+3.0*yy*zz**2-15.0*yy+6.0*zy*yz*zz)/4.0

# rotate
C3_gl = jnp.array(
[
[C3_gl_00, C3_gl_10, C3_gl_20, C3_gl_30, C3_gl_40, C3_gl_50, C3_gl_60],
[C3_gl_01, C3_gl_11, C3_gl_21, C3_gl_31, C3_gl_41, C3_gl_51, C3_gl_61],
[C3_gl_02, C3_gl_12, C3_gl_22, C3_gl_32, C3_gl_42, C3_gl_52, C3_gl_62],
[C3_gl_03, C3_gl_13, C3_gl_23, C3_gl_33, C3_gl_43, C3_gl_53, C3_gl_63],
[C3_gl_04, C3_gl_14, C3_gl_24, C3_gl_34, C3_gl_44, C3_gl_54, C3_gl_64],
[C3_gl_05, C3_gl_15, C3_gl_25, C3_gl_35, C3_gl_45, C3_gl_55, C3_gl_65],
[C3_gl_06, C3_gl_16, C3_gl_26, C3_gl_36, C3_gl_46, C3_gl_56, C3_gl_66]
]
).swapaxes(0,1)

Q_lh_3 = jnp.einsum('jk,k->j', C3_gl, octupoles)

if lmax == 0:
Q_lh = Q_lh_0
elif lmax == 1:
Q_lh = jnp.hstack([Q_lh_0, Q_lh_1])
elif lmax == 2:
Q_lh = jnp.hstack([Q_lh_0, Q_lh_1, Q_lh_2])
else:
Q_lh = jnp.hstack([Q_lh_0, Q_lh_1, Q_lh_2, Q_lh_3])

return Q_lh

Expand Down
Loading

0 comments on commit 8c4c767

Please sign in to comment.