From 488543e8097e171e8e00734c9a7d1e88bbe31099 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:09:25 +0800 Subject: [PATCH 01/15] Update graph.py Debug the code for dmff matching the template when force field xml file specify the same atom element type for different atoms --- dmff/api/graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dmff/api/graph.py b/dmff/api/graph.py index 5e4085581..0d56cd37d 100644 --- a/dmff/api/graph.py +++ b/dmff/api/graph.py @@ -21,7 +21,7 @@ def matchTemplate(graph, template): return False, {}, {} def match_func(n1, n2): - return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] + return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] and n1["name"] == n2["name"] def edge_match(e1, e2): if len(e1) == 0 and len(e2) == 0: From fb282cfc01b1f92859758f0089465c17180b457a Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:16:23 +0800 Subject: [PATCH 02/15] Update inter.py --- dmff/classical/inter.py | 68 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index d6f393783..f3229d02f 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -332,3 +332,71 @@ def get_energy(positions, box, pairs, bcc, mscales): charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten() return get_energy_kernel(positions, box, pairs, charges, mscales) return get_energy + + +class CustomGBForce: + # E_{GB} = -\frac12(\frac1{\epsilon_{solute}}-\frac1{\epsilon_{solvent}})\sum_{i, j}\frac{q_iq_j}{f_{GB}(d_{ij}, R_i, R_j)} + # f_{GB}(d_{ij}, R_i, R_j)=[d_{ij} ^ 2 + R_iR_jexp(\frac{-d_{ij} ^ 2}{4R_iR_j})] ^ {1 / 2} + # R_i=\frac1{\rho_i^{-1}-r_i^{-1}tanh(\alpha\Psi_i-\beta\Psi_i^2+\gamma\Psi_i^3)} + # \alpha=1,\beta=0.8,\gamma=4.85,\rho_i=r_i-0.009nm + # \Psi_i=\frac{\rho_i}{4\pi}\int_{VDW}\theta(|r|-\rho_i)\frac1{|r|^4}d^3r + # E_{SAT}=E_{SA}\cdot4\pi\sum_i(r_i+r_{solvent})^2(\frac{r_i}{R_i})^6 + def __init__( + self, + map_charge, + map_radius, + map_scale, + epsilon_1=1.0, + epsilon_solv=78.3, + alpha=1, + beta=0.8, + gamma=4.85, + ) -> None: + self.map_charge = map_charge + self.map_radius = map_radius + self.map_scale = map_scale + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.exp_solv = epsilon_solv + self.eps_1 = epsilon_1 + + def generate_get_energy(self): + @jax.jit + def get_energy(positions, box, pairs, Ipairs, charges, radius, scales): + def calI(posList, radMap, scalMap, rhoMap, pairMap): + I = jnp.array([]) + for i in range(len(radMap)): + posj = posList[Ipairs[i]] + rhoj = rhoMap[Ipairs[i]] + scalj = scalMap[Ipairs[i]] + posi = posList[i] + rhoi = rhoMap[i] + r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1)) + sr2 = rhoj * scalj + D = jnp.abs(r - sr2) + L = jnp.maximum(D, rhoi) + C = 2 * (1 / rhoi - 1 / L) * jnp.heaviside(sr2 - r - rhoi, 1) + U = r + sr2 + I = jnp.append(I, jnp.sum(0.5 * jnp.heaviside(r + sr2 - rhoi, 1) * ( + 1 / L - 1 / U + 0.25 * (1 / U ** 2 - 1 / L ** 2) * ( + r - sr2 ** 2 / r) + 0.5 * jnp.log(L / U) / r + C))) + + return I + + chargeMap = charges[self.map_charge] + radiusMap = radius[self.map_radius] + scalesMap = scales[self.map_scale] + rhoMap = radiusMap - 0.009 + + # effective radius + IList = calI(positions, radiusMap, scalesMap, rhoMap, Ipairs) + psi = IList*rhoMap + rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap) + Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff) + dr_norm = jnp.linalg.norm(positions[pairs[:,0]] - positions[pairs[:,1]], axis=1) + chargepro = chargeMap[pairs[:, 0]] * chargeMap[pairs[:, 1]] + rEffpro = rEff[pairs[:, 0]] * rEff[pairs[:, 1]] + Egb = jnp.sum(-138.935456*(1/self.eps_1-1/self.exp_solv)*chargepro/jnp.sqrt(jnp.power(dr_norm, 2)+rEffpro*jnp.exp(-jnp.power(dr_norm,2)/(4*rEffpro)))) + return Ese + Egb + return get_energy From e552874d84cb377d71771aa7998dc4a0a9dde957 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:17:14 +0800 Subject: [PATCH 03/15] Update intra.py --- dmff/classical/intra.py | 75 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py index 0625f48ee..32cc4c738 100644 --- a/dmff/classical/intra.py +++ b/dmff/classical/intra.py @@ -136,3 +136,78 @@ def refresh_calculators(self): """ self.get_energy = self.generate_get_energy() self.get_forces = value_and_grad(self.get_energy) + + +class Custom1_5BondJaxForce: + def __init__(self, p1idx, p2idx, prmidx): + self.p1idx = p1idx + self.p2idx = p2idx + self.prmidx = prmidx + self.refresh_calculators() + + def generate_get_energy(self): + def get_energy(positions, box, pairs, k, length): + p1 = positions[self.p1idx,:] + p2 = positions[self.p2idx,:] + kprm = k[self.prmidx] + b0prm = length[self.prmidx] + dist = distance(p1, p2) + return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2)) + + return get_energy + + def update_env(self, attr, val): + """ + Update the environment of the calculator + """ + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + """ + refresh the energy and force calculators according to the current environment + """ + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) + + +class CustomTorsionJaxForce: + def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx, order): + self.p1idx = p1idx + self.p2idx = p2idx + self.p3idx = p3idx + self.p4idx = p4idx + self.prmidx = prmidx + self.order = order + self.refresh_calculators() + + def generate_get_energy(self): + if len(self.p1idx) == 0: + return lambda positions, box, pairs, k, psi, shift: 0.0 + def get_energy(positions, box, pairs, k, psi, shift): + p1 = positions[self.p1idx, :] + p2 = positions[self.p2idx, :] + p3 = positions[self.p3idx, :] + p4 = positions[self.p4idx, :] + kp = k[self.prmidx] + psip = psi[self.prmidx] + shiftp = shift[self.prmidx] + dih = dihedral(p1, p2, p3, p4) + ener = kp * (jnp.cos(self.order * dih - psip)) + shiftp + return jnp.sum(ener) + + return get_energy + + def update_env(self, attr, val): + """ + Update the environment of the calculator + """ + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + """ + refresh the energy and force calculators according to the current environment + """ + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) From a0120dac31d3e042eab3e430e63f1144adfc71cd Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:20:16 +0800 Subject: [PATCH 04/15] Update classical.py --- dmff/generators/classical.py | 716 ++++++++++++++++++++++++++++++++++- 1 file changed, 714 insertions(+), 2 deletions(-) diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index b330c87e2..f2993ea9b 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -7,8 +7,8 @@ import jax.numpy as jnp import openmm.app as app import openmm.unit as unit -from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce -from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce +from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, , Custom1_5BondJaxForce, CustomTorsionJaxForce +from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce, CustomGBForce from typing import Tuple, List, Union, Callable @@ -1398,3 +1398,715 @@ def potential_fn(positions, box, pairs, params, aux=None): _DMFFGenerators["LennardJonesForce"] = LennardJonesGenerator + + +class Custom1_5BondGenerator: + """ + A class for generating harmonic bond force field parameters. + + Attributes: + ----------- + name : str + The name of the force field. + ffinfo : dict + The force field information. + key_type : str + The type of the key. + bond_keys : list of tuple + The keys of the bonds. + bond_params : list of tuple + The parameters of the bonds. + bond_mask : list of float + The mask of the bonds. + """ + + def __init__(self, ffinfo: dict, paramset: ParamSet): + """ + Initializes the HarmonicBondGenerator. + + Parameters: + ----------- + ffinfo : dict + The force field information. + paramset : ParamSet + The parameter set. + """ + self.name = "Custom1_5BondForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.key_type = None + + bond_keys, bond_params = [], [] + for node in self.ffinfo["Forces"][self.name]["node"]: + attribs = node["attrib"] + if self.key_type is None: + if "atomIndex1" in attribs: + self.key_type = "atomIndex" + else: + raise ValueError( + "Cannot find key type for Custom1_5BondForce.") + key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"]) + bond_keys.append(key) + k = float(attribs["k"]) + r0 = float(attribs["length"]) + bond_params.append([k, r0]) + + self.bond_keys = bond_keys + bond_length = jnp.array([i[1] for i in bond_params]) + bond_k = jnp.array([i[0] for i in bond_params]) + + # register parameters to ParamSet + paramset.addParameter(bond_length, "length", + field=self.name) + # register parameters to ParamSet + paramset.addParameter(bond_k, "k", field=self.name) + + def getName(self) -> str: + """ + Returns the name of the force field. + + Returns: + -------- + str + The name of the force field. + """ + return self.name + + def overwrite(self, paramset: ParamSet) -> None: + """ + Overwrites the parameter set. + + Parameters: + ----------- + paramset : ParamSet + The parameter set. + """ + bond_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Bond"] + + bond_length = paramset[self.name]["length"] + bond_k = paramset[self.name]["k"] + for nnode, key in enumerate(self.bond_keys): + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + r0 = bond_length[nnode] + k = bond_k[nnode] + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"]["k"] = str(k) + self.ffinfo["Forces"][self.name]["node"][bond_node_indices[nnode] + ]["attrib"]["length"] = str(r0) + + def _find_key_index(self, key: Tuple[str, str]) -> int: + """ + Finds the index of the key. + + Parameters: + ----------- + key : tuple of str + The key. + + Returns: + -------- + int + The index of the key. + """ + for i, k in enumerate(self.bond_keys): + if k[0] == key[0] and k[1] == key[1]: + return i + if k[0] == key[1] and k[1] == key[0]: + return i + return None + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + """ + Creates the potential. + + Parameters: + ----------- + topdata : DMFFTopology + The topology data. + nonbondedMethod : str + The nonbonded method. + nonbondedCutoff : float + The nonbonded cutoff. + args : list + The arguments. + + Returns: + -------- + function + The potential function. + """ + bond_a1, bond_a2, bond_indices = [], [], [] + for i, k in enumerate(self.bond_keys): + bond_a1.append(int(k[0])) + bond_a2.append(int(k[1])) + bond_indices.append(int(i)) + bond_a1 = jnp.array(bond_a1) + bond_a2 = jnp.array(bond_a2) + bond_indices = jnp.array(bond_indices) + + # 创建势函数 + harmonic_bond_force = HarmonicBondJaxForce( + bond_a1, bond_a2, bond_indices) + harmonic_bond_energy = harmonic_bond_force.generate_get_energy() + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): + isinstance_jnp(positions, box, params) + energy = harmonic_bond_energy( + positions, box, pairs, params[self.name]["k"], params[self.name]["length"]) + if has_aux: + return energy, aux + else: + return energy + + self._jaxPotential = potential_fn + return potential_fn + + +# register the generator +_DMFFGenerators["Custom1_5BondForce"] = Custom1_5BondGenerator + + +class CustomGBGenerator: + """ + A class for generating Custom Generalized Born implicit solvation models. + The following code implements the OBC variant of the GB/SA solvation model, using the ACE approximation to estimate surface area. + + Attributes: + ----------- + name : str + The name of the force field. + ffinfo : dict + The force field information. + key_type : str + The type of the key. + perParticleKey : list of tuple + The keys of the atoms + + """ + + def __init__(self, ffinfo: dict, paramset: ParamSet): + """ + Initialize the CustomGBForceGenerator + + Parameters: + ----------- + ffinfo : dict + The force field information. + paramset : ParamSet + The parameter set. + + """ + self.name = "CustomGBForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.key_type = None + self.perParticleParamIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"] + + perParticleKey, perParticleParam, chargeMask = [], [], [] + for i in self.perParticleParamIndices: + attribs = self.ffinfo["Forces"][self.name]["node"][i]["attrib"] + if self.key_type is None: + if "type" in attribs: + self.key_type = "type" + elif "class" in attribs: + self.key_type = "class" + else: + raise ValueError( + "Cannot find key type for CustomGBForce." + ) + key = (attribs[self.key_type]) + perParticleKey.append(key) + + charge = float(attribs["charge"]) + radius = float(attribs["radius"]) + scale = float(attribs["scale"]) + + # Parameter Charge is not trainable + chargeMask.append(0.0) + perParticleParam.append([charge, radius, scale]) + + self.perParticleKey = perParticleKey + paramset.addParameter(jnp.array([i[0] for i in perParticleParam]), + "charge", field=self.name, mask=chargeMask) + paramset.addParameter(jnp.array([i[1] for i in perParticleParam]), + "radius", field=self.name) + paramset.addParameter(jnp.array([i[2] for i in perParticleParam]), + "scale", field=self.name) + + + def getName(self) -> str: + """ + Returns the name of the force field. + + Returns: + -------- + str + The name of the force field. + """ + return self.name + + + def overwrite(self, paramset: ParamSet) -> None: + """ + Overwrites the parameter set. + + Parameters: + ----------- + paramset : ParamSet + The parameter set. + """ + radius = paramset[self.name]["radius"] + scale = paramset[self.name]["scale"] + for i in self.ffinfo.perParticleParamIndices: + self.ffinfo["Forces"][self.name]["node"][i]["attrib"]["radius"] = str(radius[i]) + self.ffinfo["Forces"][self.name]["node"][i]["attrib"]["scale"] = str(scale[i]) + + def _find_key_index(self, key: Tuple[str]) -> int: + """ + Finds the index of the key. + + Parameters: + ----------- + key : tuple of str + The key. + + Returns: + -------- + int + The index of the key. + """ + for i, k in enumerate(self.perParticleKey): + if k == key: + return i + return None + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs): + """ + Creates the potential. + + Parameters: + ----------- + topdata : DMFFTopology + The topology data. + nonbondedMethod : str + The nonbonded method. + nonbondedCutoff : float + The nonbonded cutoff. + args : list + The arguments. + + Returns: + -------- + function + The potential function. + """ + # Load CustomGBForce parameters + charge_indices, radius_indices, scale_indices = [], [] ,[] + for atom in topdata.atoms(): + if self.key_type == "type": + key = (atom.meta["type"]) + elif self.key_type == "class": + key = (atom.meta["class"]) + idx = self._find_key_index(key) + if idx is None: + continue + charge_indices.append(idx) + radius_indices.append(idx) + scale_indices.append(idx) + + charge_indices = jnp.array(charge_indices) + radius_indices = jnp.array(radius_indices) + scale_indices = jnp.array(scale_indices) + + customGBforce = CustomGBForce(charge_indices, radius_indices, scale_indices) + GBSAOBCenergy = customGBforce.generate_get_energy() + + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet): + pairs = pairs[:int(positions.shape[0]*(positions.shape[0]-1)/2)] + tt = np.vstack((pairs, pairs[:,[1, 0, 2]])) + Ipair = [] + for i in range(positions.shape[0]): + Ipair.append([pair[1] for pair in tt if pair[0] == i]) + Ipair = jnp.array(Ipair) + energy = GBSAOBCenergy(positions, box, pairs, Ipair, + params[self.name]["charge"], + params[self.name]["radius"], + params[self.name]['scale']) + return energy + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["CustomGBForce"] = CustomGBGenerator + + +class CustomTorsionGenerator: + + def __init__(self, ffinfo: dict, paramset: ParamSet): + """ + Initializes a PeriodicTorsionForce object. + + Args: + - ffinfo (dict): A dictionary containing force field information. + - paramset (ParamSet): A ParamSet object to register parameters. + + Returns: + - None + """ + self.name = "CustomTorsionForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self._use_smarts = False + self.key_type = None + self.torsionIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if "roper" in self.ffinfo["Forces"][self.name]["node"][i]["name"]] + + proper_keys, proper_periods, proper_prms, proper_shift = [], [], [], [] + proper_key_to_prms = {} + improper_keys, improper_periods, improper_prms, improper_shift = [], [], [], [] + improper_key_to_prms = {} + for i in self.torsionIndices: + node = self.ffinfo["Forces"][self.name]["node"][i] + attribs = node["attrib"] + if "type1" in attribs: + self.key_type = "type" + elif "class1" in attribs: + self.key_type = "class" + key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"], + attribs[self.key_type + "3"], attribs[self.key_type + "4"]) + if node["name"] == "Proper": + proper_keys.append(key) + elif node["name"] == "Improper": + improper_keys.append(key) + + mask = 1.0 + if "mask" in attribs and attribs["mask"].upper() == "TRUE": + mask = 0.0 + + for period_key in attribs.keys(): + if "per" not in period_key: + continue + order = int(period_key.replace("per", "")) + period = int(attribs[period_key]) + phase = float(attribs["phase" + str(order)]) + k = float(attribs["k" + str(order)]) + shift = float(attribs["shift"])/4 + if node["name"] == "Proper": + proper_periods.append(period) + proper_prms.append([phase, k, mask, shift]) + if len(proper_keys) - 1 not in proper_key_to_prms: + proper_key_to_prms[len(proper_keys) - 1] = [] + proper_key_to_prms[len( + proper_keys) - 1].append(len(proper_periods) - 1) + elif node["name"] == "Improper": + improper_periods.append(period) + improper_prms.append([phase, k, mask, shift]) + if len(improper_keys) - 1 not in improper_key_to_prms: + improper_key_to_prms[len(improper_keys) - 1] = [] + improper_key_to_prms[len( + improper_keys) - 1].append(len(improper_periods) - 1) + + self.proper_keys = proper_keys + self.proper_periods = jnp.array(proper_periods) + self.proper_key_to_prms = proper_key_to_prms + proper_phase = jnp.array([i[0] for i in proper_prms]) + proper_k = jnp.array([i[1] for i in proper_prms]) + proper_mask = jnp.array([i[2] for i in proper_prms]) + proper_shift = jnp.array([i[3] for i in proper_prms]) + # register parameters to ParamSet + paramset.addParameter(proper_phase, "proper_phase", + field=self.name, mask=proper_mask) + paramset.addParameter(proper_k, "proper_k", + field=self.name, mask=proper_mask) + paramset.addParameter(proper_shift, "proper_shift", + field=self.name, mask=proper_mask) + + self.imp_keys = improper_keys + self.imp_periods = jnp.array(improper_periods) + self.imp_key_to_prms = improper_key_to_prms + improper_phase = jnp.array([i[0] for i in improper_prms]) + improper_k = jnp.array([i[1] for i in improper_prms]) + improper_mask = jnp.array([i[2] for i in improper_prms]) + improper_shift = jnp.array([i[3] for i in improper_prms]) + # register parameters to ParamSet + paramset.addParameter(improper_phase, "improper_phase", + field=self.name, mask=improper_mask) + paramset.addParameter(improper_k, "improper_k", + field=self.name, mask=improper_mask) + paramset.addParameter(improper_shift, "improper_shift", + field=self.name, mask=improper_mask) + + def getName(self): + return self.name + + def overwrite(self, paramset): + # paramset to ffinfo + proper_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if + self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Proper"] + improper_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if + self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Improper"] + + proper_phase = paramset[self.name]["proper_phase"] + proper_k = paramset[self.name]["proper_k"] + proper_shift = paramset[self.name]["proper_shift"] + proper_msks = paramset.mask[self.name]["proper"] + for nnode, key in enumerate(self.proper_keys): + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}3"] = key[2] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}4"] = key[3] + shiftTem = 0 + for nitem, item in enumerate(self.proper_key_to_prms[nnode]): + phase, k, shift = proper_phase[item], proper_k[item], proper_shift[item] + mask = proper_msks[item] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["per" + str(nitem + 1)] = str(self.proper_periods[item]) + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["phase" + str(nitem + 1)] = str(phase) + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["k" + str(nitem + 1)] = str(k) + shiftTem += shift + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["shift"] = str(shiftTem) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["mask"] = "true" + + improper_phase = paramset[self.name]["improper_phase"] + improper_k = paramset[self.name]["improper_k"] + improper_shift = paramset[self.name]["improper_shift"] + improper_msks = paramset.mask[self.name]["improper"] + for nnode, key in enumerate(self.imp_keys): + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}3"] = key[2] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}4"] = key[3] + shiftTem = 0 + for nitem, item in enumerate(self.imp_key_to_prms[nnode]): + phase = improper_phase[item] + k = improper_k[item] + shift = improper_shift[item] + mask = improper_msks[item] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["per" + str(nitem + 1)] = str(self.imp_periods[item]) + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["phase" + str(nitem + 1)] = str(phase) + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["k" + str(nitem + 1)] = str(k) + shiftTem += shift + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["shift"] = str(shiftTem) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["mask"] = "true" + + def _find_proper_key_index(self, key: Tuple[str, str, str, str]) -> int: + wc_patch = [] + for i, k in enumerate(self.proper_keys): + if k[0] in ["", key[0]] and k[1] in ["", key[1]] and k[2] in ["", key[2]] and k[3] in ["", key[3]]: + if "" in k: + wc_patch.append(i) + else: + return i + if k[0] in ["", key[3]] and k[1] in ["", key[2]] and k[2] in ["", key[1]] and k[3] in ["", key[0]]: + if "" in k: + wc_patch.append(i) + else: + return i + if len(wc_patch) > 0: + return wc_patch[0] + return None + + def _find_improper_key_index(self, improper): + + type1 = improper[0].meta[self.key_type] + type2 = improper[1].meta[self.key_type] + type3 = improper[2].meta[self.key_type] + type4 = improper[3].meta[self.key_type] + + def _wild_match(tp, tps): + if tps == "": + return True + if tp == tps: + return True + return False + + matched = None + for ndef, tordef in enumerate(self.imp_keys): + types1 = tordef[0] + types2 = tordef[1] + types3 = tordef[2] + types4 = tordef[3] + hasWildcard = ("" in (types1, types2, types3, types4)) + + if matched is not None and hasWildcard: + continue + + import itertools + if type1 in types1: + for (t2, t3, t4) in itertools.permutations(((type2, 1), (type3, 2), (type4, 3))): + if _wild_match(t2[0], types2) and _wild_match(t3[0], types3) and _wild_match(t4[0], types4): + a1 = improper[t2[1]].index + a2 = improper[t3[1]].index + e1 = improper[t2[1]].element + e2 = improper[t3[1]].element + m1 = app.element.get_by_symbol(e1).mass + m2 = app.element.get_by_symbol(e2).mass + if e1 == e2 and a1 > a2: + (a1, a2) = (a2, a1) + elif e1 != "C" and (e2 == "C" or m1 < m2): + (a1, a2) = (a2, a1) + matched = (a1, a2, improper[0].index, improper[t4[1]].index, ndef) + break + if matched is None: + return None, None + return matched[4], matched[:4] + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + + if self.key_type is None: + def potential_fn_zero(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, + params: ParamSet) -> jnp.ndarray: + return jnp.zeros((1,)) + + self._jaxPotential = potential_fn_zero + return potential_fn_zero + + proper_list = [] + + acenters = {} + atoms = [a for a in topdata.atoms()] + for bond in topdata.bonds(): + a1, a2 = bond.atom1, bond.atom2 + i1, i2 = a1.index, a2.index + if i1 not in acenters: + acenters[i1] = [] + acenters[i1].append(i2) + if i2 not in acenters: + acenters[i2] = [] + acenters[i2].append(i1) + + # find rotamers and loop over proper torsions on the rotamer + for bond in topdata.bonds(): + a1, a2 = bond.atom1, bond.atom2 + i1, i2 = a1.index, a2.index + alinks1 = [i for i in acenters[i1] if i != i2] + alinks2 = [i for i in acenters[i2] if i != i1] + for i3 in alinks1: + for i4 in alinks2: + if i3 != i4: + proper_list.append( + (atoms[i3], atoms[i1], atoms[i2], atoms[i4])) + + impr_list = [] + # find atoms that link with three other atoms + import itertools as it + for i1 in acenters: + if len(acenters[i1]) < 3: + continue + for item in it.combinations(acenters[i1], 3): + impr_list.append( + (atoms[i1], atoms[item[0]], atoms[item[1]], atoms[item[2]])) + + # create potential + proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period = [ + ], [], [], [], [], [] + for proper in proper_list: + pidx = self._find_proper_key_index( + (proper[0].meta[self.key_type], proper[1].meta[self.key_type], proper[2].meta[self.key_type], + proper[3].meta[self.key_type])) + if pidx is None: + continue + + prm_indices = self.proper_key_to_prms[pidx] + for prm_idx in prm_indices: + prm_period = self.proper_periods[prm_idx] + proper_a1.append(proper[0].index) + proper_a2.append(proper[1].index) + proper_a3.append(proper[2].index) + proper_a4.append(proper[3].index) + proper_indices.append(prm_idx) + proper_period.append(prm_period) + + proper_a1 = jnp.array(proper_a1) + proper_a2 = jnp.array(proper_a2) + proper_a3 = jnp.array(proper_a3) + proper_a4 = jnp.array(proper_a4) + proper_indices = jnp.array(proper_indices) + proper_period = jnp.array(proper_period) + + improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period = [], [], [], [], [], [] + for improper in impr_list: + iidx, order = self._find_improper_key_index(improper) + if iidx is None: + continue + + prm_indices = self.imp_key_to_prms[iidx] + for prm_idx in prm_indices: + prm_period = self.imp_periods[prm_idx] + improper_a1.append(atoms[order[0]].index) + improper_a2.append(atoms[order[1]].index) + improper_a3.append(atoms[order[2]].index) + improper_a4.append(atoms[order[3]].index) + improper_indices.append(prm_idx) + improper_period.append(prm_period) + improper_a1 = jnp.array(improper_a1) + improper_a2 = jnp.array(improper_a2) + improper_a3 = jnp.array(improper_a3) + improper_a4 = jnp.array(improper_a4) + improper_indices = jnp.array(improper_indices) + improper_period = jnp.array(improper_period) + + proper_func = CustomTorsionJaxForce( + proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period) + proper_energy = proper_func.generate_get_energy() + improper_func = CustomTorsionJaxForce( + improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period) + improper_energy = improper_func.generate_get_energy() + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): + isinstance_jnp(positions, box, params) + proper_energy_ = proper_energy( + positions, box, pairs, params[self.name]["proper_k"], params[self.name]["proper_phase"], params[self.name]["proper_shift"]) + improper_energy_ = improper_energy( + positions, box, pairs, params[self.name]["improper_k"], params[self.name]["improper_phase"], params[self.name]["improper_shift"]) + if has_aux: + return proper_energy_ + improper_energy_, aux + else: + return proper_energy_ + improper_energy_ + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["CustomTorsionForce"] = CustomTorsionGenerator From 29577626ecf2261cdaf04e8069f5276223f77e7c Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:21:38 +0800 Subject: [PATCH 05/15] Add files via upload --- dmff/generators/classical.py | 876 ++++++++++++++++++----------------- 1 file changed, 439 insertions(+), 437 deletions(-) diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index f2993ea9b..e4d247d19 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import openmm.app as app import openmm.unit as unit -from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, , Custom1_5BondJaxForce, CustomTorsionJaxForce +from ..classical.intra import HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, Custom1_5BondJaxForce, CustomTorsionJaxForce from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, LennardJonesForce, LennardJonesLongRangeForce, CustomGBForce from typing import Tuple, List, Union, Callable @@ -62,6 +62,8 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): raise ValueError( "Cannot find key type for HarmonicBondForce.") key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"]) + + bond_keys.append(key) k = float(attribs["k"]) @@ -787,98 +789,458 @@ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, p _DMFFGenerators["PeriodicTorsionForce"] = PeriodicTorsionGenerator -class NonbondedGenerator: +class CustomTorsionGenerator: + def __init__(self, ffinfo: dict, paramset: ParamSet): - self.name = "NonbondedForce" + """ + Initializes a PeriodicTorsionForce object. + + Args: + - ffinfo (dict): A dictionary containing force field information. + - paramset (ParamSet): A ParamSet object to register parameters. + + Returns: + - None + """ + self.name = "CustomTorsionForce" self.ffinfo = ffinfo paramset.addField(self.name) - self.coulomb14scale = float( - self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("coulomb14scale", 0.8333333333333334)) - self.lj14scale = float( - self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("lj14scale", 0.5)) + self._use_smarts = False self.key_type = None - self.type_to_charge = {} - - self.charge_in_residue = False - for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: - if not self.charge_in_residue and node["name"] == "UseAttributeFromResidue": - if node["attrib"]["name"] == "charge": - self.charge_in_residue = True - - types, sigma, epsilon, atom_mask = [], [], [], [] - for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: - if node["name"] == "Atom": - attribs = node["attrib"] - self.key_type = None - if "type" in attribs: - self.key_type = "type" - elif "class" in attribs: - self.key_type = "class" - types.append(attribs[self.key_type]) - sigma.append(float(attribs["sigma"])) - epsilon.append(float(attribs["epsilon"])) - mask = 1.0 - if "mask" in attribs and attribs["mask"].upper() == "TRUE": - mask = 0.0 - atom_mask.append(mask) - if not self.charge_in_residue: - if "charge" not in attribs: - raise ValueError("No charge information found in NonbondedForce or Residues.") - self.type_to_charge[attribs[self.key_type]] = float(attribs["charge"]) + self.torsionIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if "roper" in self.ffinfo["Forces"][self.name]["node"][i]["name"]] - sigma = jnp.array(sigma) - epsilon = jnp.array(epsilon) - atom_mask = jnp.array(atom_mask) - self.atom_keys = types - paramset.addParameter(sigma, "sigma", field=self.name, mask=atom_mask) - paramset.addParameter(epsilon, "epsilon", field=self.name, mask=atom_mask) + proper_keys, proper_periods, proper_prms, proper_shift = [], [], [], [] + proper_key_to_prms = {} + improper_keys, improper_periods, improper_prms, improper_shift = [], [], [], [] + improper_key_to_prms = {} + for i in self.torsionIndices: + node = self.ffinfo["Forces"][self.name]["node"][i] + attribs = node["attrib"] + if "type1" in attribs: + self.key_type = "type" + elif "class1" in attribs: + self.key_type = "class" + key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"], + attribs[self.key_type + "3"], attribs[self.key_type + "4"]) + if node["name"] == "Proper": + proper_keys.append(key) + elif node["name"] == "Improper": + improper_keys.append(key) + + mask = 1.0 + if "mask" in attribs and attribs["mask"].upper() == "TRUE": + mask = 0.0 + + for period_key in attribs.keys(): + if "per" not in period_key: + continue + order = int(period_key.replace("per", "")) + period = int(attribs[period_key]) + phase = float(attribs["phase" + str(order)]) + k = float(attribs["k" + str(order)]) + shift = float(attribs["shift"])/4 + if node["name"] == "Proper": + proper_periods.append(period) + proper_prms.append([phase, k, mask, shift]) + if len(proper_keys) - 1 not in proper_key_to_prms: + proper_key_to_prms[len(proper_keys) - 1] = [] + proper_key_to_prms[len( + proper_keys) - 1].append(len(proper_periods) - 1) + elif node["name"] == "Improper": + improper_periods.append(period) + improper_prms.append([phase, k, mask, shift]) + if len(improper_keys) - 1 not in improper_key_to_prms: + improper_key_to_prms[len(improper_keys) - 1] = [] + improper_key_to_prms[len( + improper_keys) - 1].append(len(improper_periods) - 1) + + self.proper_keys = proper_keys + self.proper_periods = jnp.array(proper_periods) + self.proper_key_to_prms = proper_key_to_prms + proper_phase = jnp.array([i[0] for i in proper_prms]) + proper_k = jnp.array([i[1] for i in proper_prms]) + proper_mask = jnp.array([i[2] for i in proper_prms]) + proper_shift = jnp.array([i[3] for i in proper_prms]) + # register parameters to ParamSet + paramset.addParameter(proper_phase, "proper_phase", + field=self.name, mask=proper_mask) + paramset.addParameter(proper_k, "proper_k", + field=self.name, mask=proper_mask) + paramset.addParameter(proper_shift, "proper_shift", + field=self.name, mask=proper_mask) + + self.imp_keys = improper_keys + self.imp_periods = jnp.array(improper_periods) + self.imp_key_to_prms = improper_key_to_prms + improper_phase = jnp.array([i[0] for i in improper_prms]) + improper_k = jnp.array([i[1] for i in improper_prms]) + improper_mask = jnp.array([i[2] for i in improper_prms]) + improper_shift = jnp.array([i[3] for i in improper_prms]) + # register parameters to ParamSet + paramset.addParameter(improper_phase, "improper_phase", + field=self.name, mask=improper_mask) + paramset.addParameter(improper_k, "improper_k", + field=self.name, mask=improper_mask) + paramset.addParameter(improper_shift, "improper_shift", + field=self.name, mask=improper_mask) def getName(self): return self.name def overwrite(self, paramset): - sigma = paramset[self.name]["sigma"] - epsilon = paramset[self.name]["epsilon"] - atom_mask = paramset.mask[self.name]["sigma"] + # paramset to ffinfo + proper_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if + self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Proper"] + improper_node_indices = [i for i in range(len( + self.ffinfo["Forces"][self.name]["node"])) if + self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Improper"] - node2atom = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"] + proper_phase = paramset[self.name]["proper_phase"] + proper_k = paramset[self.name]["proper_k"] + proper_shift = paramset[self.name]["proper_shift"] + proper_msks = paramset.mask[self.name]["proper"] + for nnode, key in enumerate(self.proper_keys): + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}3"] = key[2] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"][f"{self.key_type}4"] = key[3] + shiftTem = 0 + for nitem, item in enumerate(self.proper_key_to_prms[nnode]): + phase, k, shift = proper_phase[item], proper_k[item], proper_shift[item] + mask = proper_msks[item] + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["per" + str(nitem + 1)] = str(self.proper_periods[item]) + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["phase" + str(nitem + 1)] = str(phase) + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["k" + str(nitem + 1)] = str(k) + shiftTem += shift + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["shift"] = str(shiftTem) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] + ]["attrib"]["mask"] = "true" - for natom in range(len(self.atom_keys)): - nnode = node2atom[natom] - sig_new = sigma[natom] - eps_new = epsilon[natom] - mask = atom_mask[natom] - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = str(sig_new) - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = str(eps_new) + improper_phase = paramset[self.name]["improper_phase"] + improper_k = paramset[self.name]["improper_k"] + improper_shift = paramset[self.name]["improper_shift"] + improper_msks = paramset.mask[self.name]["improper"] + for nnode, key in enumerate(self.imp_keys): + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]]["attrib"] = { + } + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}1"] = key[0] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}2"] = key[1] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}3"] = key[2] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"][f"{self.key_type}4"] = key[3] + shiftTem = 0 + for nitem, item in enumerate(self.imp_key_to_prms[nnode]): + phase = improper_phase[item] + k = improper_k[item] + shift = improper_shift[item] + mask = improper_msks[item] + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["per" + str(nitem + 1)] = str(self.imp_periods[item]) + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["phase" + str(nitem + 1)] = str(phase) + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["k" + str(nitem + 1)] = str(k) + shiftTem += shift + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["shift"] = str(shiftTem) if mask < 0.999: - self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" + self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] + ]["attrib"]["mask"] = "true" - def _find_atype_key_index(self, atype: str): - for n, i in enumerate(self.atom_keys): - if i == atype: - return n + def _find_proper_key_index(self, key: Tuple[str, str, str, str]) -> int: + wc_patch = [] + for i, k in enumerate(self.proper_keys): + if k[0] in ["", key[0]] and k[1] in ["", key[1]] and k[2] in ["", key[2]] and k[3] in ["", key[3]]: + if "" in k: + wc_patch.append(i) + else: + return i + if k[0] in ["", key[3]] and k[1] in ["", key[2]] and k[2] in ["", key[1]] and k[3] in ["", key[0]]: + if "" in k: + wc_patch.append(i) + else: + return i + if len(wc_patch) > 0: + return wc_patch[0] return None - - def createPotential(self, topdata: DMFFTopology, nonbondedMethod, - nonbondedCutoff, **kwargs): - methodMap = { - app.NoCutoff: "NoCutoff", - app.CutoffPeriodic: "CutoffPeriodic", - app.CutoffNonPeriodic: "CutoffNonPeriodic", - app.PME: "PME", - } - methodString = methodMap[nonbondedMethod] - if nonbondedMethod not in methodMap: - raise DMFFException("Illegal nonbonded method for NonbondedForce") - isNoCut = False - if nonbondedMethod is app.NoCutoff: - isNoCut = True + def _find_improper_key_index(self, improper): - mscales_coul = jnp.array([0.0, 0.0, self.coulomb14scale, 1.0, 1.0, - 1.0]) - mscales_lj = jnp.array([0.0, 0.0, self.lj14scale, 1.0, 1.0, - 1.0]) + type1 = improper[0].meta[self.key_type] + type2 = improper[1].meta[self.key_type] + type3 = improper[2].meta[self.key_type] + type4 = improper[3].meta[self.key_type] + + def _wild_match(tp, tps): + if tps == "": + return True + if tp == tps: + return True + return False + + matched = None + for ndef, tordef in enumerate(self.imp_keys): + types1 = tordef[0] + types2 = tordef[1] + types3 = tordef[2] + types4 = tordef[3] + hasWildcard = ("" in (types1, types2, types3, types4)) + + if matched is not None and hasWildcard: + continue + + import itertools + if type1 in types1: + for (t2, t3, t4) in itertools.permutations(((type2, 1), (type3, 2), (type4, 3))): + if _wild_match(t2[0], types2) and _wild_match(t3[0], types3) and _wild_match(t4[0], types4): + a1 = improper[t2[1]].index + a2 = improper[t3[1]].index + e1 = improper[t2[1]].element + e2 = improper[t3[1]].element + m1 = app.element.get_by_symbol(e1).mass + m2 = app.element.get_by_symbol(e2).mass + if e1 == e2 and a1 > a2: + (a1, a2) = (a2, a1) + elif e1 != "C" and (e2 == "C" or m1 < m2): + (a1, a2) = (a2, a1) + matched = (a1, a2, improper[0].index, improper[t4[1]].index, ndef) + break + if matched is None: + return None, None + return matched[4], matched[:4] + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + + if self.key_type is None: + def potential_fn_zero(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, + params: ParamSet) -> jnp.ndarray: + return jnp.zeros((1,)) + + self._jaxPotential = potential_fn_zero + return potential_fn_zero + + proper_list = [] + + acenters = {} + atoms = [a for a in topdata.atoms()] + for bond in topdata.bonds(): + a1, a2 = bond.atom1, bond.atom2 + i1, i2 = a1.index, a2.index + if i1 not in acenters: + acenters[i1] = [] + acenters[i1].append(i2) + if i2 not in acenters: + acenters[i2] = [] + acenters[i2].append(i1) + + # find rotamers and loop over proper torsions on the rotamer + for bond in topdata.bonds(): + a1, a2 = bond.atom1, bond.atom2 + i1, i2 = a1.index, a2.index + alinks1 = [i for i in acenters[i1] if i != i2] + alinks2 = [i for i in acenters[i2] if i != i1] + for i3 in alinks1: + for i4 in alinks2: + if i3 != i4: + proper_list.append( + (atoms[i3], atoms[i1], atoms[i2], atoms[i4])) + + impr_list = [] + # find atoms that link with three other atoms + import itertools as it + for i1 in acenters: + if len(acenters[i1]) < 3: + continue + for item in it.combinations(acenters[i1], 3): + impr_list.append( + (atoms[i1], atoms[item[0]], atoms[item[1]], atoms[item[2]])) + + # create potential + proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period = [ + ], [], [], [], [], [] + for proper in proper_list: + pidx = self._find_proper_key_index( + (proper[0].meta[self.key_type], proper[1].meta[self.key_type], proper[2].meta[self.key_type], + proper[3].meta[self.key_type])) + if pidx is None: + continue + + prm_indices = self.proper_key_to_prms[pidx] + for prm_idx in prm_indices: + prm_period = self.proper_periods[prm_idx] + proper_a1.append(proper[0].index) + proper_a2.append(proper[1].index) + proper_a3.append(proper[2].index) + proper_a4.append(proper[3].index) + proper_indices.append(prm_idx) + proper_period.append(prm_period) + + proper_a1 = jnp.array(proper_a1) + proper_a2 = jnp.array(proper_a2) + proper_a3 = jnp.array(proper_a3) + proper_a4 = jnp.array(proper_a4) + proper_indices = jnp.array(proper_indices) + proper_period = jnp.array(proper_period) + + improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period = [], [], [], [], [], [] + for improper in impr_list: + iidx, order = self._find_improper_key_index(improper) + if iidx is None: + continue + + prm_indices = self.imp_key_to_prms[iidx] + for prm_idx in prm_indices: + prm_period = self.imp_periods[prm_idx] + improper_a1.append(atoms[order[0]].index) + improper_a2.append(atoms[order[1]].index) + improper_a3.append(atoms[order[2]].index) + improper_a4.append(atoms[order[3]].index) + improper_indices.append(prm_idx) + improper_period.append(prm_period) + improper_a1 = jnp.array(improper_a1) + improper_a2 = jnp.array(improper_a2) + improper_a3 = jnp.array(improper_a3) + improper_a4 = jnp.array(improper_a4) + improper_indices = jnp.array(improper_indices) + improper_period = jnp.array(improper_period) + + proper_func = CustomTorsionJaxForce( + proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period) + proper_energy = proper_func.generate_get_energy() + improper_func = CustomTorsionJaxForce( + improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period) + improper_energy = improper_func.generate_get_energy() + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): + isinstance_jnp(positions, box, params) + proper_energy_ = proper_energy( + positions, box, pairs, params[self.name]["proper_k"], params[self.name]["proper_phase"], params[self.name]["proper_shift"]) + improper_energy_ = improper_energy( + positions, box, pairs, params[self.name]["improper_k"], params[self.name]["improper_phase"], params[self.name]["improper_shift"]) + if has_aux: + return proper_energy_ + improper_energy_, aux + else: + return proper_energy_ + improper_energy_ + + self._jaxPotential = potential_fn + return potential_fn + + +_DMFFGenerators["CustomTorsionForce"] = CustomTorsionGenerator + + +class NonbondedGenerator: + def __init__(self, ffinfo: dict, paramset: ParamSet): + self.name = "NonbondedForce" + self.ffinfo = ffinfo + paramset.addField(self.name) + self.coulomb14scale = float( + self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("coulomb14scale", 0.8333333333333334)) + self.lj14scale = float( + self.ffinfo["Forces"]["NonbondedForce"]["meta"].get("lj14scale", 0.5)) + self.key_type = None + self.type_to_charge = {} + + self.charge_in_residue = False + for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: + if not self.charge_in_residue and node["name"] == "UseAttributeFromResidue": + if node["attrib"]["name"] == "charge": + self.charge_in_residue = True + + types, sigma, epsilon, atom_mask = [], [], [], [] + for node in self.ffinfo["Forces"]["NonbondedForce"]["node"]: + if node["name"] == "Atom": + attribs = node["attrib"] + self.key_type = None + if "type" in attribs: + self.key_type = "type" + elif "class" in attribs: + self.key_type = "class" + types.append(attribs[self.key_type]) + sigma.append(float(attribs["sigma"])) + epsilon.append(float(attribs["epsilon"])) + mask = 1.0 + if "mask" in attribs and attribs["mask"].upper() == "TRUE": + mask = 0.0 + atom_mask.append(mask) + if not self.charge_in_residue: + if "charge" not in attribs: + raise ValueError("No charge information found in NonbondedForce or Residues.") + self.type_to_charge[attribs[self.key_type]] = float(attribs["charge"]) + + sigma = jnp.array(sigma) + epsilon = jnp.array(epsilon) + atom_mask = jnp.array(atom_mask) + self.atom_keys = types + paramset.addParameter(sigma, "sigma", field=self.name, mask=atom_mask) + paramset.addParameter(epsilon, "epsilon", field=self.name, mask=atom_mask) + + def getName(self): + return self.name + + def overwrite(self, paramset): + sigma = paramset[self.name]["sigma"] + epsilon = paramset[self.name]["epsilon"] + atom_mask = paramset.mask[self.name]["sigma"] + + node2atom = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Atom"] + + for natom in range(len(self.atom_keys)): + nnode = node2atom[natom] + sig_new = sigma[natom] + eps_new = epsilon[natom] + mask = atom_mask[natom] + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["sigma"] = str(sig_new) + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["epsilon"] = str(eps_new) + if mask < 0.999: + self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["mask"] = "true" + + def _find_atype_key_index(self, atype: str): + for n, i in enumerate(self.atom_keys): + if i == atype: + return n + return None + + def createPotential(self, topdata: DMFFTopology, nonbondedMethod, + nonbondedCutoff, **kwargs): + methodMap = { + app.NoCutoff: "NoCutoff", + app.CutoffPeriodic: "CutoffPeriodic", + app.CutoffNonPeriodic: "CutoffNonPeriodic", + app.PME: "PME", + } + methodString = methodMap[nonbondedMethod] + if nonbondedMethod not in methodMap: + raise DMFFException("Illegal nonbonded method for NonbondedForce") + + isNoCut = False + if nonbondedMethod is app.NoCutoff: + isNoCut = True + + mscales_coul = jnp.array([0.0, 0.0, self.coulomb14scale, 1.0, 1.0, + 1.0]) + mscales_lj = jnp.array([0.0, 0.0, self.lj14scale, 1.0, 1.0, + 1.0]) # coulomb # set PBC @@ -1747,366 +2109,6 @@ def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, p self._jaxPotential = potential_fn return potential_fn - - -_DMFFGenerators["CustomGBForce"] = CustomGBGenerator - - -class CustomTorsionGenerator: - - def __init__(self, ffinfo: dict, paramset: ParamSet): - """ - Initializes a PeriodicTorsionForce object. - - Args: - - ffinfo (dict): A dictionary containing force field information. - - paramset (ParamSet): A ParamSet object to register parameters. - - Returns: - - None - """ - self.name = "CustomTorsionForce" - self.ffinfo = ffinfo - paramset.addField(self.name) - self._use_smarts = False - self.key_type = None - self.torsionIndices = [i for i in range(len(self.ffinfo["Forces"][self.name]["node"])) if "roper" in self.ffinfo["Forces"][self.name]["node"][i]["name"]] - - proper_keys, proper_periods, proper_prms, proper_shift = [], [], [], [] - proper_key_to_prms = {} - improper_keys, improper_periods, improper_prms, improper_shift = [], [], [], [] - improper_key_to_prms = {} - for i in self.torsionIndices: - node = self.ffinfo["Forces"][self.name]["node"][i] - attribs = node["attrib"] - if "type1" in attribs: - self.key_type = "type" - elif "class1" in attribs: - self.key_type = "class" - key = (attribs[self.key_type + "1"], attribs[self.key_type + "2"], - attribs[self.key_type + "3"], attribs[self.key_type + "4"]) - if node["name"] == "Proper": - proper_keys.append(key) - elif node["name"] == "Improper": - improper_keys.append(key) - - mask = 1.0 - if "mask" in attribs and attribs["mask"].upper() == "TRUE": - mask = 0.0 - - for period_key in attribs.keys(): - if "per" not in period_key: - continue - order = int(period_key.replace("per", "")) - period = int(attribs[period_key]) - phase = float(attribs["phase" + str(order)]) - k = float(attribs["k" + str(order)]) - shift = float(attribs["shift"])/4 - if node["name"] == "Proper": - proper_periods.append(period) - proper_prms.append([phase, k, mask, shift]) - if len(proper_keys) - 1 not in proper_key_to_prms: - proper_key_to_prms[len(proper_keys) - 1] = [] - proper_key_to_prms[len( - proper_keys) - 1].append(len(proper_periods) - 1) - elif node["name"] == "Improper": - improper_periods.append(period) - improper_prms.append([phase, k, mask, shift]) - if len(improper_keys) - 1 not in improper_key_to_prms: - improper_key_to_prms[len(improper_keys) - 1] = [] - improper_key_to_prms[len( - improper_keys) - 1].append(len(improper_periods) - 1) - - self.proper_keys = proper_keys - self.proper_periods = jnp.array(proper_periods) - self.proper_key_to_prms = proper_key_to_prms - proper_phase = jnp.array([i[0] for i in proper_prms]) - proper_k = jnp.array([i[1] for i in proper_prms]) - proper_mask = jnp.array([i[2] for i in proper_prms]) - proper_shift = jnp.array([i[3] for i in proper_prms]) - # register parameters to ParamSet - paramset.addParameter(proper_phase, "proper_phase", - field=self.name, mask=proper_mask) - paramset.addParameter(proper_k, "proper_k", - field=self.name, mask=proper_mask) - paramset.addParameter(proper_shift, "proper_shift", - field=self.name, mask=proper_mask) - - self.imp_keys = improper_keys - self.imp_periods = jnp.array(improper_periods) - self.imp_key_to_prms = improper_key_to_prms - improper_phase = jnp.array([i[0] for i in improper_prms]) - improper_k = jnp.array([i[1] for i in improper_prms]) - improper_mask = jnp.array([i[2] for i in improper_prms]) - improper_shift = jnp.array([i[3] for i in improper_prms]) - # register parameters to ParamSet - paramset.addParameter(improper_phase, "improper_phase", - field=self.name, mask=improper_mask) - paramset.addParameter(improper_k, "improper_k", - field=self.name, mask=improper_mask) - paramset.addParameter(improper_shift, "improper_shift", - field=self.name, mask=improper_mask) - - def getName(self): - return self.name - - def overwrite(self, paramset): - # paramset to ffinfo - proper_node_indices = [i for i in range(len( - self.ffinfo["Forces"][self.name]["node"])) if - self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Proper"] - improper_node_indices = [i for i in range(len( - self.ffinfo["Forces"][self.name]["node"])) if - self.ffinfo["Forces"][self.name]["node"][i]["name"] == "Improper"] - - proper_phase = paramset[self.name]["proper_phase"] - proper_k = paramset[self.name]["proper_k"] - proper_shift = paramset[self.name]["proper_shift"] - proper_msks = paramset.mask[self.name]["proper"] - for nnode, key in enumerate(self.proper_keys): - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode]]["attrib"] = { - } - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"][f"{self.key_type}1"] = key[0] - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"][f"{self.key_type}2"] = key[1] - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"][f"{self.key_type}3"] = key[2] - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"][f"{self.key_type}4"] = key[3] - shiftTem = 0 - for nitem, item in enumerate(self.proper_key_to_prms[nnode]): - phase, k, shift = proper_phase[item], proper_k[item], proper_shift[item] - mask = proper_msks[item] - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"]["per" + str(nitem + 1)] = str(self.proper_periods[item]) - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"]["phase" + str(nitem + 1)] = str(phase) - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"]["k" + str(nitem + 1)] = str(k) - shiftTem += shift - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"]["shift"] = str(shiftTem) - if mask < 0.999: - self.ffinfo["Forces"][self.name]["node"][proper_node_indices[nnode] - ]["attrib"]["mask"] = "true" - - improper_phase = paramset[self.name]["improper_phase"] - improper_k = paramset[self.name]["improper_k"] - improper_shift = paramset[self.name]["improper_shift"] - improper_msks = paramset.mask[self.name]["improper"] - for nnode, key in enumerate(self.imp_keys): - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode]]["attrib"] = { - } - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"][f"{self.key_type}1"] = key[0] - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"][f"{self.key_type}2"] = key[1] - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"][f"{self.key_type}3"] = key[2] - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"][f"{self.key_type}4"] = key[3] - shiftTem = 0 - for nitem, item in enumerate(self.imp_key_to_prms[nnode]): - phase = improper_phase[item] - k = improper_k[item] - shift = improper_shift[item] - mask = improper_msks[item] - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"]["per" + str(nitem + 1)] = str(self.imp_periods[item]) - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"]["phase" + str(nitem + 1)] = str(phase) - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"]["k" + str(nitem + 1)] = str(k) - shiftTem += shift - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"]["shift"] = str(shiftTem) - if mask < 0.999: - self.ffinfo["Forces"][self.name]["node"][improper_node_indices[nnode] - ]["attrib"]["mask"] = "true" - - def _find_proper_key_index(self, key: Tuple[str, str, str, str]) -> int: - wc_patch = [] - for i, k in enumerate(self.proper_keys): - if k[0] in ["", key[0]] and k[1] in ["", key[1]] and k[2] in ["", key[2]] and k[3] in ["", key[3]]: - if "" in k: - wc_patch.append(i) - else: - return i - if k[0] in ["", key[3]] and k[1] in ["", key[2]] and k[2] in ["", key[1]] and k[3] in ["", key[0]]: - if "" in k: - wc_patch.append(i) - else: - return i - if len(wc_patch) > 0: - return wc_patch[0] - return None - - def _find_improper_key_index(self, improper): - - type1 = improper[0].meta[self.key_type] - type2 = improper[1].meta[self.key_type] - type3 = improper[2].meta[self.key_type] - type4 = improper[3].meta[self.key_type] - - def _wild_match(tp, tps): - if tps == "": - return True - if tp == tps: - return True - return False - - matched = None - for ndef, tordef in enumerate(self.imp_keys): - types1 = tordef[0] - types2 = tordef[1] - types3 = tordef[2] - types4 = tordef[3] - hasWildcard = ("" in (types1, types2, types3, types4)) - - if matched is not None and hasWildcard: - continue - - import itertools - if type1 in types1: - for (t2, t3, t4) in itertools.permutations(((type2, 1), (type3, 2), (type4, 3))): - if _wild_match(t2[0], types2) and _wild_match(t3[0], types3) and _wild_match(t4[0], types4): - a1 = improper[t2[1]].index - a2 = improper[t3[1]].index - e1 = improper[t2[1]].element - e2 = improper[t3[1]].element - m1 = app.element.get_by_symbol(e1).mass - m2 = app.element.get_by_symbol(e2).mass - if e1 == e2 and a1 > a2: - (a1, a2) = (a2, a1) - elif e1 != "C" and (e2 == "C" or m1 < m2): - (a1, a2) = (a2, a1) - matched = (a1, a2, improper[0].index, improper[t4[1]].index, ndef) - break - if matched is None: - return None, None - return matched[4], matched[:4] - - def createPotential(self, topdata: DMFFTopology, nonbondedMethod, - nonbondedCutoff, **kwargs): - - if self.key_type is None: - def potential_fn_zero(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, - params: ParamSet) -> jnp.ndarray: - return jnp.zeros((1,)) - - self._jaxPotential = potential_fn_zero - return potential_fn_zero - - proper_list = [] - - acenters = {} - atoms = [a for a in topdata.atoms()] - for bond in topdata.bonds(): - a1, a2 = bond.atom1, bond.atom2 - i1, i2 = a1.index, a2.index - if i1 not in acenters: - acenters[i1] = [] - acenters[i1].append(i2) - if i2 not in acenters: - acenters[i2] = [] - acenters[i2].append(i1) - - # find rotamers and loop over proper torsions on the rotamer - for bond in topdata.bonds(): - a1, a2 = bond.atom1, bond.atom2 - i1, i2 = a1.index, a2.index - alinks1 = [i for i in acenters[i1] if i != i2] - alinks2 = [i for i in acenters[i2] if i != i1] - for i3 in alinks1: - for i4 in alinks2: - if i3 != i4: - proper_list.append( - (atoms[i3], atoms[i1], atoms[i2], atoms[i4])) - - impr_list = [] - # find atoms that link with three other atoms - import itertools as it - for i1 in acenters: - if len(acenters[i1]) < 3: - continue - for item in it.combinations(acenters[i1], 3): - impr_list.append( - (atoms[i1], atoms[item[0]], atoms[item[1]], atoms[item[2]])) - - # create potential - proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period = [ - ], [], [], [], [], [] - for proper in proper_list: - pidx = self._find_proper_key_index( - (proper[0].meta[self.key_type], proper[1].meta[self.key_type], proper[2].meta[self.key_type], - proper[3].meta[self.key_type])) - if pidx is None: - continue - - prm_indices = self.proper_key_to_prms[pidx] - for prm_idx in prm_indices: - prm_period = self.proper_periods[prm_idx] - proper_a1.append(proper[0].index) - proper_a2.append(proper[1].index) - proper_a3.append(proper[2].index) - proper_a4.append(proper[3].index) - proper_indices.append(prm_idx) - proper_period.append(prm_period) - - proper_a1 = jnp.array(proper_a1) - proper_a2 = jnp.array(proper_a2) - proper_a3 = jnp.array(proper_a3) - proper_a4 = jnp.array(proper_a4) - proper_indices = jnp.array(proper_indices) - proper_period = jnp.array(proper_period) - - improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period = [], [], [], [], [], [] - for improper in impr_list: - iidx, order = self._find_improper_key_index(improper) - if iidx is None: - continue - prm_indices = self.imp_key_to_prms[iidx] - for prm_idx in prm_indices: - prm_period = self.imp_periods[prm_idx] - improper_a1.append(atoms[order[0]].index) - improper_a2.append(atoms[order[1]].index) - improper_a3.append(atoms[order[2]].index) - improper_a4.append(atoms[order[3]].index) - improper_indices.append(prm_idx) - improper_period.append(prm_period) - improper_a1 = jnp.array(improper_a1) - improper_a2 = jnp.array(improper_a2) - improper_a3 = jnp.array(improper_a3) - improper_a4 = jnp.array(improper_a4) - improper_indices = jnp.array(improper_indices) - improper_period = jnp.array(improper_period) - proper_func = CustomTorsionJaxForce( - proper_a1, proper_a2, proper_a3, proper_a4, proper_indices, proper_period) - proper_energy = proper_func.generate_get_energy() - improper_func = CustomTorsionJaxForce( - improper_a1, improper_a2, improper_a3, improper_a4, improper_indices, improper_period) - improper_energy = improper_func.generate_get_energy() - - has_aux = False - if "has_aux" in kwargs and kwargs["has_aux"]: - has_aux = True - - def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None): - isinstance_jnp(positions, box, params) - proper_energy_ = proper_energy( - positions, box, pairs, params[self.name]["proper_k"], params[self.name]["proper_phase"], params[self.name]["proper_shift"]) - improper_energy_ = improper_energy( - positions, box, pairs, params[self.name]["improper_k"], params[self.name]["improper_phase"], params[self.name]["improper_shift"]) - if has_aux: - return proper_energy_ + improper_energy_, aux - else: - return proper_energy_ + improper_energy_ - - self._jaxPotential = potential_fn - return potential_fn - - -_DMFFGenerators["CustomTorsionForce"] = CustomTorsionGenerator +_DMFFGenerators["CustomGBForce"] = CustomGBGenerator \ No newline at end of file From d853e41145d3860f4dd097aae1d4d164756576dc Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 14:22:46 +0800 Subject: [PATCH 06/15] Add files via upload --- dmff/classical/inter.py | 25 ++++++++++++++++++++++--- dmff/classical/intra.py | 2 +- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index f3229d02f..3e15a75f2 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,8 +1,9 @@ from typing import Iterable, Tuple, Optional import jax import jax.numpy as jnp +import jax.config as config +config.update("jax_enable_x64", True) import numpy as np - from ..utils import pair_buffer_scales, regularize_pairs from ..admp.pme import energy_pme from ..admp.recip import generate_pme_recip @@ -362,16 +363,25 @@ def __init__( self.eps_1 = epsilon_1 def generate_get_energy(self): - @jax.jit + # @jax.jit def get_energy(positions, box, pairs, Ipairs, charges, radius, scales): def calI(posList, radMap, scalMap, rhoMap, pairMap): + # posList [numOfAtoms, 3] + # radMap [numOfAtoms] + # pair1 = pairMap[:, 0] + # pair2 = pairMap[:, 1] I = jnp.array([]) + for i in range(len(radMap)): + # posj = posList[jnp.append(pair2[jnp.where(pair1 == i)], (pair1[jnp.where(pair2 == i)]))] + # rhoj = rhoMap[jnp.append(pair2[jnp.where(pair1 == i)], (pair1[jnp.where(pair2 == i)]))] + # scalj = scalMap[jnp.append(pair2[jnp.where(pair1 == i)], (pair1[jnp.where(pair2 == i)]))] posj = posList[Ipairs[i]] rhoj = rhoMap[Ipairs[i]] scalj = scalMap[Ipairs[i]] posi = posList[i] rhoi = rhoMap[i] + r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1)) sr2 = rhoj * scalj D = jnp.abs(r - sr2) @@ -393,10 +403,19 @@ def calI(posList, radMap, scalMap, rhoMap, pairMap): IList = calI(positions, radiusMap, scalesMap, rhoMap, Ipairs) psi = IList*rhoMap rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap) + # surface area term energy Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff) + # generalized born term energy + # distance calculated from atom pair [i,j] where i Date: Thu, 28 Dec 2023 14:23:44 +0800 Subject: [PATCH 07/15] Update inter.py --- dmff/classical/inter.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 3e15a75f2..26cd9885a 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -336,12 +336,6 @@ def get_energy(positions, box, pairs, bcc, mscales): class CustomGBForce: - # E_{GB} = -\frac12(\frac1{\epsilon_{solute}}-\frac1{\epsilon_{solvent}})\sum_{i, j}\frac{q_iq_j}{f_{GB}(d_{ij}, R_i, R_j)} - # f_{GB}(d_{ij}, R_i, R_j)=[d_{ij} ^ 2 + R_iR_jexp(\frac{-d_{ij} ^ 2}{4R_iR_j})] ^ {1 / 2} - # R_i=\frac1{\rho_i^{-1}-r_i^{-1}tanh(\alpha\Psi_i-\beta\Psi_i^2+\gamma\Psi_i^3)} - # \alpha=1,\beta=0.8,\gamma=4.85,\rho_i=r_i-0.009nm - # \Psi_i=\frac{\rho_i}{4\pi}\int_{VDW}\theta(|r|-\rho_i)\frac1{|r|^4}d^3r - # E_{SAT}=E_{SA}\cdot4\pi\sum_i(r_i+r_{solvent})^2(\frac{r_i}{R_i})^6 def __init__( self, map_charge, @@ -366,22 +360,13 @@ def generate_get_energy(self): # @jax.jit def get_energy(positions, box, pairs, Ipairs, charges, radius, scales): def calI(posList, radMap, scalMap, rhoMap, pairMap): - # posList [numOfAtoms, 3] - # radMap [numOfAtoms] - # pair1 = pairMap[:, 0] - # pair2 = pairMap[:, 1] I = jnp.array([]) - for i in range(len(radMap)): - # posj = posList[jnp.append(pair2[jnp.where(pair1 == i)], (pair1[jnp.where(pair2 == i)]))] - # rhoj = rhoMap[jnp.append(pair2[jnp.where(pair1 == i)], (pair1[jnp.where(pair2 == i)]))] - # scalj = scalMap[jnp.append(pair2[jnp.where(pair1 == i)], (pair1[jnp.where(pair2 == i)]))] posj = posList[Ipairs[i]] rhoj = rhoMap[Ipairs[i]] scalj = scalMap[Ipairs[i]] posi = posList[i] rhoi = rhoMap[i] - r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1)) sr2 = rhoj * scalj D = jnp.abs(r - sr2) @@ -405,17 +390,9 @@ def calI(posList, radMap, scalMap, rhoMap, pairMap): rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap) # surface area term energy Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff) - # generalized born term energy - # distance calculated from atom pair [i,j] where i Date: Thu, 28 Dec 2023 15:39:56 +0800 Subject: [PATCH 08/15] Add files via upload --- docs/user_guide/4.1classical.md | 225 +++++++++++++++++++++++++++++--- 1 file changed, 209 insertions(+), 16 deletions(-) diff --git a/docs/user_guide/4.1classical.md b/docs/user_guide/4.1classical.md index 1dabe12e6..70185cf55 100644 --- a/docs/user_guide/4.1classical.md +++ b/docs/user_guide/4.1classical.md @@ -4,7 +4,9 @@ The chemical bond is represented by a harmonic potential: -$$ E = \frac{1}{2}k(b-b_0)^2 $$ +$$ +E = \frac{1}{2}k(b-b_0)^2 +$$ where $k$ is the force constant, $b$ is the distance betweeen two particles that forming a bond and $b_0$ is the equilibrium bond length. Note that in some other MD softwares, the potential form adopts a slight different form: $E=k(b-b_0)^2$. Users should check which form to use and multiply (or divide) the force constant by 2. @@ -21,18 +23,19 @@ The way to specify a harmonic bond in DMFF is the same as the way doing it in Op ``` -Every `` tag defines a rule for creating harmonic bond interactions between atoms. Each tag may identify the atoms either by type (using the attributes `type1` and `type2`) or by class (using the attributes `class1` and `class2`). `length` is the equilibrium bond length in $\mathrm{nm}$, and `k` is the force constant in $\mathrm{kJ/mol/nm^2}$. +Every `` tag defines a rule for creating harmonic bond interactions between atoms. Each tag may identify the atoms either by type (using the attributes `type1` and `type2`) or by class (using the attributes `class1` and `class2`). `length` is the equilibrium bond length in $\mathrm{nm}$, and `k` is the force constant in $\mathrm{kJ/mol/nm^2}$. When the tag has an attribute named `mask` and it's value set to `true`, this means the parameter is not trainable. Such information will be passed to `ParamSet.mask` (the corresponding mask value will be 0.0 if not trainable). - # HarmonicAngleJaxForce ## 1. Theory The angle is represented by a harmonic potential: -$$ E = \frac{1}{2}k(\theta - \theta_0)^2 $$ +$$ +E = \frac{1}{2}k(\theta - \theta_0)^2 +$$ where $k$ is the force constant, $\theta$ is the angle between three particles and $\theta_0$ is the equilibrium angle value. Similiar to `HarmonicBondJaxForce`, the parameters in some other MD softwares are defined to follow the potential form: $E=k(\theta-\theta_0)^2$. Be aware to adjust the parameters properly when applying the DMFF parameters to other softwares. @@ -48,6 +51,7 @@ The way to specify a harmonic angle in DMFF is the same as the way doing it in O ... ``` + Every `` tag defines a rule for creating harmonic angle interactions between triplets of atoms. Each tag may identify the atoms either by type (using the attributes `type1`, `type2`, `type3`) or by class (using the attributes `class1`, `class2`, `class3`). The force field identifies every set of three atoms in the system where the first is bonded to the second, and the second to the third. `angle` is the equilibrium angle in radians, and `k` is the spring constant in $\mathrm{kJ/mol/radian^2}$. When the tag has an attribute named `mask` and it's value set to `true`, this means the parameter is not trainable. Such information will be passed to `ParamSet.mask` (the corresponding mask value will be 0.0 if not trainable). @@ -58,7 +62,9 @@ When the tag has an attribute named `mask` and it's value set to `true`, this me The torsion is represented by a truncated periodic Fourier series: -$$ E = \sum_{n=0}^{6} k_n(1 + \cos(n\phi-\phi_{0n})) $$ +$$ +E = \sum_{n=0}^{6} k_n(1 + \cos(n\phi-\phi_{0n})) +$$ where $\phi$ is the dihedral angle formed by four particles, $n$ is the periodicity, $\phi_{0n}$ is the phase offset $k_{n}$ is the force constant. To perserve the symmetry, $\phi_{0n}$ usually adopts a value of $0$ (for $n=1,3,5$) or $\pi$ ($n=2,4,6$), and it is recommened to follow these definitions and not to optimize them in force field development. @@ -76,9 +82,10 @@ The way to specify a periodic torsion in DMFF is the same as the way doing it in ... ``` + Every child tag defines a rule for creating periodic torsion interactions between sets of four atoms. Each tag may identify the atoms either by type (using the attributes `type1`, `type2`, ...) or by class (using the attributes `class1`, `class2`, ...). -The force field recognizes two different types of torsions: `Proper` and `Improper`. A proper torsion involves four atoms that are bonded in sequence: 1 to 2, 2 to 3, and 3 to 4. An improper torsion involves a central atom and three others that are bonded to it: atoms 2, 3, and 4 are all bonded to atom 1. `periodicity1` is the periodicity of the torsion, `phase1` is the phase offset in radians, and `k1` is the force constant in kJ/mol. To add a second periodicity, just add three more attributes: `periodicity2`, `phase2`, and `k2`. **The maxium periodicity supported in DMFF is 6, which is different from OpenMM**. +The force field recognizes two different types of torsions: `Proper` and `Improper`. A proper torsion involves four atoms that are bonded in sequence: 1 to 2, 2 to 3, and 3 to 4. An improper torsion involves a central atom and three others that are bonded to it: atoms 2, 3, and 4 are all bonded to atom 1. `periodicity1` is the periodicity of the torsion, `phase1` is the phase offset in radians, and `k1` is the force constant in kJ/mol. To add a second periodicity, just add three more attributes: `periodicity2`, `phase2`, and `k2`. **The maxium periodicity supported in DMFF is 6, which is different from OpenMM**. You can also use wildcards when defining torsions. To do this, simply leave the type or class name for an atom empty. That will cause it to match any atom: @@ -90,19 +97,25 @@ When the tag has an attribute named `mask` and it's value set to `true`, this me # LennardJonesForce -## 1. Theory +## 1. Theory The Lennard-Jones intearction between two particles follows the potential form: -$$ E = 4\epsilon\left(\frac{\sigma^{12}}{r^{12}}-\frac{\sigma^{6}}{r^{6}}\right) $$ +$$ +E = 4\epsilon\left(\frac{\sigma^{12}}{r^{12}}-\frac{\sigma^{6}}{r^{6}}\right) +$$ where $r$ is the distance between two particles, $\epsilon$ is the depth of the potential wall and $\sigma$ defines the distance where the interaction energy is zero. The pairwise parameter $\sigma$ and $\epsilon$ are determined from the parameters of the individual particles using the Lorentz-Berthelot combining rule: -$$\sigma=\frac{\sigma_1+\sigma_2}{2}$$ +$$ +\sigma=\frac{\sigma_1+\sigma_2}{2} +$$ -$$\epsilon = \sqrt{\epsilon_1\epsilon_2} $$ +$$ +\epsilon = \sqrt{\epsilon_1\epsilon_2} +$$ ## 2. Frontend @@ -138,8 +151,10 @@ Note that the excluded pairs (interaction between particles seperated by 1 or 2 The form of the Coulomb interaction between each pair of particles depends on the `NonbondedMethod` in use. For NoCutoff, it is given by -$$ E = \frac{1}{4\pi\epsilon_0}\frac{q_1q_2}{r} $$ - +$$ +E = \frac{1}{4\pi\epsilon_0}\frac{q_1q_2}{r} +$$ + where $q_1$ and $q_2$ are the charges of the two particles, and $r$ is the distance between them. $\epsilon_0$ is the dielectric constant of vacuum. ## 2. Frontend @@ -202,14 +217,11 @@ Notice that the atomic charges are not specified in this tag because different f ``` - - # NonbondedForce ## 1. Theory -The `NonbondedForce` is a summary of `CoulombForce` and `LennardJonesForce` for consistency with OpenMM. With `NonbondedForce`, the force field library of OpenMM can be fluentely used in DMFF. The form of the Lennard-Jones and Coulomb interaction between each pair of particles depends on the `NonbondedMethod` in use. - +The `NonbondedForce` is a summary of `CoulombForce` and `LennardJonesForce` for consistency with OpenMM. With `NonbondedForce`, the force field library of OpenMM can be fluentely used in DMFF. The form of the Lennard-Jones and Coulomb interaction between each pair of particles depends on the `NonbondedMethod` in use. ## 2. Frontend @@ -224,6 +236,7 @@ To specify NonbondedForce interactions, include a tag that looks like this: ``` The attribute `coulomb14scale` and `lj14scale` specifies the scale factors between pairs of atoms separated by three bonds. The atomic charges are defined in a template-based manner since the `UseAttributeFromResidue` node is added. If node `UseAttributeFromResidue` does not exist, the atomic charges should be specified with `Atom` node, such as: + ```xml @@ -231,3 +244,183 @@ The attribute `coulomb14scale` and `lj14scale` specifies the scale factors betwe ``` +# CustomGBJaxForce + +## 1. Theory + +### Generalized Born Term + +The force consists of two energy terms: a Generalized Born Approximation term to represent the electrostatic interaction between the solute and solvent, and a surface area term to represent the free energy cost of solvating a neutral molecule. The Generalized Born energy is given by + +$$ +E = -\frac{1}{2}(\frac{1}{\epsilon_{solute}}-\frac{1}{\epsilon_{solvent}})\sum_{i,j}\frac{q_iq_j}{f_{GB}(d_{ij},R_i,R_j)} +$$ + +where the indices $i$ and $j$ run over all particles, $\epsilon_{solute}$ and $\epsilon_{solvent}$ are the dielectric constants of the solute and solvent respectively, $q_i$ is the charge of particle i, and $d_{ij}$ is the distance between particles i and j. And $f_{GB}(d_{ij},R_i,R_j)$ is defined as: + +$$ +f_{GB}(d_{ij},R_i,R_j)=[d^2_{ij}+R_iR_jexp(-\frac{d^2_{ij}}{4R_iR_j})]^{\frac{1}{2}} +$$ + +$R_i$ is the Born radius of particle i, which calculated as: + +$$ +R_i = \frac{1}{\rho_i^{-1}-r_i^{-1}tanh(\alpha\Psi_i-\beta\Psi_i^2+\gamma\Psi_i^3)} +$$ + +where $\alpha,\beta,\gamma$ are the $GB^{OBC}II$ parameters $\alpha=1, \beta=0.8,\gamma=4.85$. $\rho_i$ is the adjusted atomic radius of particle i, which is calculated from the atomic radius $r_i$ as $\rho_i=r_i-0.009$ nm. $\Psi_i$ is calculated as an integral over the van der Waals spheres of all particles outside particle i: + +$$ +\Psi_i=\frac{\rho_i}{4\pi}\int_{VDM}\theta(|r|-\rho_i)\frac{1}{|r|^4}d^3r +$$ + +where $\theta(r)$ is a step function that excludes the interior of particle i from the integral. + +### Surface Area Term + +The surface area term is given by: + +$$ +E=E_{SA}·4\pi\sum_i(r_i+r_{solvent})^2(\frac{r_i}{R_i})^6 +$$ + +where $r_i$ is the atomic radius of particle i, $r_i$ is its atomic radius, and $r_{solvent}$ is the solvent radius, which is taken to be 0.14 nm. The default value for the energy scale $E_{SA}$ is 2.25936 kJ/mol/nm2. + +## 2. Frontend + +The way to specify a CustomGBJaxForce in DMFF is the same as the way doing it in OpenMM with CustomGBForce: + +```xml + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + ... + +``` + +Every `` tag defines a rule for creating CustomGBForce interactions between atoms. Each tag may identify the atoms either by type (using the attributes `type1` and `type2`) or by class (using the attributes `class1` and `class2`). + + + +# CustomTorsionJaxForce + +## 1. Theory + +The torsion is represented by a truncated periodic Fourier series: + +$$ +E = \sum_{n=0}^{4} k_n(\cos(n\phi-\phi_{0n})) + shift +$$ + +where $\phi$ is the dihedral angle formed by four particles, $n$ is the periodicity, $\phi_{0n}$ is the phase offset $k_{n}$ is the force constant. To perserve the symmetry, $\phi_{0n}$ usually adopts a value of $0$ (for $n=1,3,5$) or $\pi$ ($n=2,4,6$), and it is recommened to follow these definitions and not to optimize them in force field development. + +## 2. Frontend + +The way to specify a custom torsion in DMFF is the same as the way doing it in OpenMM: + +```xml + + + + + + + + + + + + + + + + + + +``` + +Every child tag `` or `` defines a rule for creating periodic torsion interactions between sets of four atoms. Each tag may identify the atoms either by type (using the attributes `type1`, `type2`, ...) or by class (using the attributes `class1`, `class2`, ...). + +The force field recognizes two different types of torsions: `Proper` and `Improper`. A proper torsion involves four atoms that are bonded in sequence: 1 to 2, 2 to 3, and 3 to 4. An improper torsion involves a central atom and three others that are bonded to it: atoms 2, 3, and 4 are all bonded to atom 1. `per1` is the periodicity of the torsion, `phase1` is the phase offset in radians, and `k1` is the force constant in kJ/mol. To add a second periodicity, just add three more attributes: `per2`, `phase2`, and `k2`. **The maxium periodicity supported in DMFF is 6, which is different from OpenMM**. + +You can also use wildcards when defining torsions. To do this, simply leave the type or class name for an atom empty. That will cause it to match any atom: + +```xml + +``` + +When the tag has an attribute named `mask` and it's value set to `true`, this means the parameter is not trainable. Such information will be passed to `ParamSet.mask` (the corresponding mask value will be 0.0 if not trainable). + + +# Custom1_5BondJaxForce + +## 1. Theory + +The force is used to regulate the atoms relation between atom 1 to 5 in coarse-grained polyphosphate. + +$$ +E = \frac{1}{2}k(b-b_0)^2 +$$ + +where $k$ is the force constant, $b$ is the distance betweeen two particles that forming a bond and $b_0$ is the equilibrium bond length. Note that in some other MD softwares, the potential form adopts a slight different form: $E=k(b-b_0)^2$. Users should check which form to use and multiply (or divide) the force constant by 2. + +## 2. Frontend + +The way to specify a harmonic bond in DMFF is different from the way doing it in OpenMM, which requires add special force `openmm.CustomCompoundBondForce()` in coding: + +```xml + + + + + + + + +``` +When using this force, you need to add `openmm.CustomCompoundBondForce()` during your simulation: + +```python +h = Hamiltonian("CG.xml") +params = h.getParameters() +compoundBondForceParam = params["Custom1_5BondForce"] +length = compoundBondForceParam["length"] +k = compoundBondForceParam["k"] +system = ff.createSystem(pdb.topology, nonbondedMethod=NoCutoff) +customCompoundForce = openmm.CustomCompoundBondForce(2, "0.5*k*(distance(p1,p2)-length)^2") +customCompoundForce.addPerBondParameter("length") +customCompoundForce.addPerBondParameter("k") +for i, leng in enumerate(length): + customCompoundForce.addBond([i, i+4], [leng, k[i]]) +system.addForce(customCompoundForce) +``` + +Every `` tag defines a rule for creating harmonic bond interactions between 1 and 5 atoms. Each tag may identify the atoms by index (using the attributes `atomIndex1` and `atomIndex2`). `length` is the equilibrium bond length in $\mathrm{nm}$, and `k` is the force constant in $\mathrm{kJ/mol/nm^2}$. \ No newline at end of file From ce84c9cf9fba80864643921f19469c1ad0667026 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 15:40:18 +0800 Subject: [PATCH 09/15] Add files via upload --- tests/test_classical/test_gbforce.py | 33 ++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/test_classical/test_gbforce.py diff --git a/tests/test_classical/test_gbforce.py b/tests/test_classical/test_gbforce.py new file mode 100644 index 000000000..8c55500ee --- /dev/null +++ b/tests/test_classical/test_gbforce.py @@ -0,0 +1,33 @@ +import pytest +import jax +import jax.numpy as jnp +import openmm.app as app +import openmm.unit as unit +import numpy as np +import numpy.testing as npt +from dmff.api import Hamiltonian +from dmff.common import nblist + + +@pytest.mark.parametrize( + "pdb, prm, value", + [ + ("../data/10p.pdb", "../data/1_5corrV2.xml", -11184.921239189738), + ("../data/pBox.pdb", "../data/polyp_amberImp.xml", -13914.34177591779), + ]) +def test_custom_gb_force(pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + potential = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + rc = 6.0 + nbl = nblist.NeighborList(box, rc, potential.meta['cov_map']) + nbl.allocate(pos) + pairs = nbl.pairs + gbE = potential.getPotentialFunc(names=["CustomGBForce"]) + energy = gbE(pos, box, pairs, h.paramset) + npt.assert_almost_equal(energy, value, decimal=3) \ No newline at end of file From 98e780972a52b649f6fb29bf6d0b0abfa014d1a1 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 15:40:54 +0800 Subject: [PATCH 10/15] Add files via upload --- tests/data/10p.pdb | 25 ++++++++ tests/data/1_5corrV2.xml | 98 ++++++++++++++++++++++++++++++ tests/data/pBox.pdb | 87 ++++++++++++++++++++++++++ tests/data/polyp_amberImp.xml | 111 ++++++++++++++++++++++++++++++++++ 4 files changed, 321 insertions(+) create mode 100644 tests/data/10p.pdb create mode 100644 tests/data/1_5corrV2.xml create mode 100644 tests/data/pBox.pdb create mode 100644 tests/data/polyp_amberImp.xml diff --git a/tests/data/10p.pdb b/tests/data/10p.pdb new file mode 100644 index 000000000..aab9da88e --- /dev/null +++ b/tests/data/10p.pdb @@ -0,0 +1,25 @@ +TITLE GRoups of Organic Molecules in ACtion for Science +REMARK THIS IS A SIMULATION BOX +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +MODEL 1 +ATOM 1 TP1 TE1 A 1 28.725 22.974 33.406 1.00 0.00 P +ATOM 2 P00 IN1 A 2 26.892 23.631 31.811 1.00 0.00 P +ATOM 3 P00 IN1 A 3 27.633 24.222 29.541 1.00 0.00 P +ATOM 4 P00 IN1 A 4 25.695 23.995 27.884 1.00 0.00 P +ATOM 5 P00 IN1 A 5 26.290 25.095 25.743 1.00 0.00 P +ATOM 6 P00 IN1 A 6 24.826 24.484 23.700 1.00 0.00 P +ATOM 7 P00 IN1 A 7 24.944 26.058 21.774 1.00 0.00 P +ATOM 8 P00 IN1 A 8 23.018 26.878 20.449 1.00 0.00 P +ATOM 9 P00 IN1 A 9 22.109 25.856 18.341 1.00 0.00 P +ATOM 10 TP1 TE1 A 10 19.868 26.802 17.355 1.00 0.00 P +TER +CONECT 1 2 +CONECT 2 3 +CONECT 3 4 +CONECT 4 5 +CONECT 5 6 +CONECT 6 7 +CONECT 7 8 +CONECT 8 9 +CONECT 9 10 +ENDMDL diff --git a/tests/data/1_5corrV2.xml b/tests/data/1_5corrV2.xml new file mode 100644 index 000000000..f64fbb47d --- /dev/null +++ b/tests/data/1_5corrV2.xml @@ -0,0 +1,98 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456*(1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/data/pBox.pdb b/tests/data/pBox.pdb new file mode 100644 index 000000000..b1a0118e1 --- /dev/null +++ b/tests/data/pBox.pdb @@ -0,0 +1,87 @@ +TITLE polyten_GMX.gro created by acpype (v: 2022.6.6) on Wed Oct 5 08:03:17 2022 +REMARK THIS IS A SIMULATION BOX +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +MODEL 1 +HETATM 1 P00 TE1 1 28.652 12.361 14.349 1.00 0.00 P +HETATM 2 O01 TE1 1 28.511 11.590 12.887 1.00 0.00 O +HETATM 3 O02 TE1 1 28.857 11.267 15.580 1.00 0.00 O +HETATM 4 O03 TE1 1 29.865 13.487 14.329 1.00 0.00 O +HETATM 5 O04 TE1 1 26.965 13.160 14.653 1.00 0.00 O +HETATM 6 P00 IN1 2 26.260 14.805 14.965 1.00 0.00 P +HETATM 7 O01 IN1 2 26.855 15.339 16.380 1.00 0.00 O +HETATM 8 O02 IN1 2 26.508 15.706 13.631 1.00 0.00 O +HETATM 9 O00 IN2 3 24.431 14.542 15.100 1.00 0.00 O +HETATM 10 P00 IN1 4 23.086 15.420 16.142 1.00 0.00 P +HETATM 11 O01 IN1 4 22.975 14.492 17.468 1.00 0.00 O +HETATM 12 O02 IN1 4 23.530 16.970 16.231 1.00 0.00 O +HETATM 13 O00 IN2 5 21.421 15.307 15.326 1.00 0.00 O +HETATM 14 P00 IN1 6 19.909 16.539 15.340 1.00 0.00 P +HETATM 15 O01 IN1 6 20.000 17.324 16.743 1.00 0.00 O +HETATM 16 O02 IN1 6 20.063 17.265 13.902 1.00 0.00 O +HETATM 17 O00 IN2 7 18.227 15.697 15.269 1.00 0.00 O +HETATM 18 P00 IN1 8 16.526 16.321 15.943 1.00 0.00 P +HETATM 19 O01 IN1 8 16.373 15.511 17.337 1.00 0.00 O +HETATM 20 O02 IN1 8 16.580 17.927 15.852 1.00 0.00 O +HETATM 21 O00 IN2 9 15.002 15.739 14.963 1.00 0.00 O +HETATM 22 P00 IN1 10 13.371 16.585 14.465 1.00 0.00 P +HETATM 23 O01 IN1 10 13.092 17.763 15.532 1.00 0.00 O +HETATM 24 O02 IN1 10 13.581 16.841 12.885 1.00 0.00 O +HETATM 25 O00 IN2 11 11.772 15.515 14.643 1.00 0.00 O +HETATM 26 P00 IN1 12 10.263 15.275 13.556 1.00 0.00 P +HETATM 27 O01 IN1 12 10.068 16.593 12.650 1.00 0.00 O +HETATM 28 O02 IN1 12 10.482 13.800 12.922 1.00 0.00 O +HETATM 29 O00 IN2 13 8.598 15.069 14.540 1.00 0.00 O +HETATM 30 P00 IN1 14 6.866 15.556 14.061 1.00 0.00 P +HETATM 31 O01 IN1 14 6.615 17.010 14.733 1.00 0.00 O +HETATM 32 O02 IN1 14 6.677 15.310 12.476 1.00 0.00 O +HETATM 33 O00 IN2 15 5.578 14.441 14.933 1.00 0.00 O +HETATM 34 P00 IN1 16 3.855 13.938 14.474 1.00 0.00 P +HETATM 35 O01 IN1 16 3.150 15.098 13.579 1.00 0.00 O +HETATM 36 O02 IN1 16 3.961 12.437 13.850 1.00 0.00 O +HETATM 37 O04 TE1 17 2.946 13.817 16.041 1.00 0.00 O +HETATM 38 P00 TE1 17 1.214 13.377 16.663 1.00 0.00 P +HETATM 39 O01 TE1 17 0.217 12.929 15.419 1.00 0.00 O +HETATM 40 O02 TE1 17 0.655 14.736 17.433 1.00 0.00 O +HETATM 41 O03 TE1 17 1.442 12.135 17.739 1.00 0.00 O +TER +CONECT 1 2 +CONECT 1 3 +CONECT 1 4 +CONECT 1 5 +CONECT 5 6 +CONECT 6 7 +CONECT 6 8 +CONECT 6 9 +CONECT 9 10 +CONECT 10 11 +CONECT 10 12 +CONECT 10 13 +CONECT 13 14 +CONECT 14 15 +CONECT 14 16 +CONECT 14 17 +CONECT 17 18 +CONECT 18 19 +CONECT 18 20 +CONECT 18 21 +CONECT 21 22 +CONECT 22 23 +CONECT 22 24 +CONECT 22 25 +CONECT 25 26 +CONECT 26 27 +CONECT 26 28 +CONECT 26 29 +CONECT 29 30 +CONECT 30 31 +CONECT 30 32 +CONECT 30 33 +CONECT 33 34 +CONECT 34 35 +CONECT 34 36 +CONECT 34 37 +CONECT 37 38 +CONECT 38 39 +CONECT 38 40 +CONECT 38 41 +ENDMDL diff --git a/tests/data/polyp_amberImp.xml b/tests/data/polyp_amberImp.xml new file mode 100644 index 000000000..15f6c9721 --- /dev/null +++ b/tests/data/polyp_amberImp.xml @@ -0,0 +1,111 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + step(r+sr2-or1)*0.5*(1/L-1/U+0.25*(1/U^2-1/L^2)*(r-sr2*sr2/r)+0.5*log(L/U)/r+C); + U=r+sr2; C=2*(1/or1-1/L)*step(sr2-r-or1); L=max(or1, D); D=abs(r-sr2); sr2 = + scale2*or2; or1 = radius1-0.009; or2 = radius2-0.009 + + + 1/(1/or-tanh(1*psi-0.8*psi^2+4.85*psi^3)/radius); psi=I*or; or=radius-0.009 + + + 28.3919551*(radius+0.14)^2*(radius/B)^6-0.5*138.935456* + (1/soluteDielectric-1/solventDielectric)*charge^2/B + + + -138.935456*(1/soluteDielectric-1/solventDielectric)*charge1*charge2/f; + f=sqrt(r^2+B1*B2*exp(-r^2/(4*B1*B2))) + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file From 9a0acd180426492766f0a0f312361bdb4572ba18 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 28 Dec 2023 15:48:27 +0800 Subject: [PATCH 11/15] Add files via upload --- tests/test_classical/test_gbforce.py | 46 ++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_classical/test_gbforce.py b/tests/test_classical/test_gbforce.py index 8c55500ee..b63da2a8f 100644 --- a/tests/test_classical/test_gbforce.py +++ b/tests/test_classical/test_gbforce.py @@ -30,4 +30,50 @@ def test_custom_gb_force(pdb, prm, value): pairs = nbl.pairs gbE = potential.getPotentialFunc(names=["CustomGBForce"]) energy = gbE(pos, box, pairs, h.paramset) + npt.assert_almost_equal(energy, value, decimal=3) + + +@pytest.mark.parametrize( + "pdb, prm, value", + [ + ("../data/10p.pdb", "../data/1_5corrV2.xml", 59.53033875302844), + ]) +def test_custom_torsion_force(pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + potential = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + rc = 6.0 + nbl = nblist.NeighborList(box, rc, potential.meta['cov_map']) + nbl.allocate(pos) + pairs = nbl.pairs + gbE = potential.getPotentialFunc(names=["CustomTorsionForce"]) + energy = gbE(pos, box, pairs, h.paramset) + npt.assert_almost_equal(energy, value, decimal=3) + + +@pytest.mark.parametrize( + "pdb, prm, value", + [ + ("../data/10p.pdb", "../data/1_5corrV2.xml", 117.95416362791674), + ]) +def test_custom_1_5bond_force(pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + potential = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + rc = 6.0 + nbl = nblist.NeighborList(box, rc, potential.meta['cov_map']) + nbl.allocate(pos) + pairs = nbl.pairs + gbE = potential.getPotentialFunc(names=["Custom1_5BondForce"]) + energy = gbE(pos, box, pairs, h.paramset) npt.assert_almost_equal(energy, value, decimal=3) \ No newline at end of file From 9478ab1a33fd82a3e5c42b2f55e9f2dbc567ed15 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 11 Jan 2024 14:36:18 +0800 Subject: [PATCH 12/15] Update mbar.py --- dmff/mbar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dmff/mbar.py b/dmff/mbar.py index 4e55e7957..88194562b 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -4,7 +4,7 @@ except ImportError: import warnings warnings.warn("MDTraj not installed. MBAREstimator is not available.") - +# try: from pymbar import MBAR except ImportError: From 7d4fdbf130ca0772641fdddff54d116eb35f07dd Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 11 Jan 2024 14:40:47 +0800 Subject: [PATCH 13/15] Update mbar.py --- dmff/mbar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dmff/mbar.py b/dmff/mbar.py index 88194562b..d615ad34e 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -4,7 +4,7 @@ except ImportError: import warnings warnings.warn("MDTraj not installed. MBAREstimator is not available.") -# +## try: from pymbar import MBAR except ImportError: From 62febfa458d98688567dc476d88ae07b2f02f618 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Thu, 11 Jan 2024 14:46:08 +0800 Subject: [PATCH 14/15] Update mbar.py --- dmff/mbar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dmff/mbar.py b/dmff/mbar.py index d615ad34e..4e55e7957 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -4,7 +4,7 @@ except ImportError: import warnings warnings.warn("MDTraj not installed. MBAREstimator is not available.") -## + try: from pymbar import MBAR except ImportError: From 6e6e3d5fff15852ae696bef51ad348864f895cc3 Mon Sep 17 00:00:00 2001 From: Ethan-Norch <56433364+Ethan-Norch@users.noreply.github.com> Date: Fri, 12 Jan 2024 10:16:05 +0800 Subject: [PATCH 15/15] Update test_gbforce.py --- tests/test_classical/test_gbforce.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_classical/test_gbforce.py b/tests/test_classical/test_gbforce.py index b63da2a8f..e13bcdb6f 100644 --- a/tests/test_classical/test_gbforce.py +++ b/tests/test_classical/test_gbforce.py @@ -12,8 +12,8 @@ @pytest.mark.parametrize( "pdb, prm, value", [ - ("../data/10p.pdb", "../data/1_5corrV2.xml", -11184.921239189738), - ("../data/pBox.pdb", "../data/polyp_amberImp.xml", -13914.34177591779), + ("tests/data/10p.pdb", "tests/data/1_5corrV2.xml", -11184.921239189738), + ("tests/data/pBox.pdb", "tests/data/polyp_amberImp.xml", -13914.34177591779), ]) def test_custom_gb_force(pdb, prm, value): pdb = app.PDBFile(pdb) @@ -36,7 +36,7 @@ def test_custom_gb_force(pdb, prm, value): @pytest.mark.parametrize( "pdb, prm, value", [ - ("../data/10p.pdb", "../data/1_5corrV2.xml", 59.53033875302844), + ("tests/data/10p.pdb", "tests/data/1_5corrV2.xml", 59.53033875302844), ]) def test_custom_torsion_force(pdb, prm, value): pdb = app.PDBFile(pdb) @@ -59,7 +59,7 @@ def test_custom_torsion_force(pdb, prm, value): @pytest.mark.parametrize( "pdb, prm, value", [ - ("../data/10p.pdb", "../data/1_5corrV2.xml", 117.95416362791674), + ("tests/data/10p.pdb", "tests/data/1_5corrV2.xml", 117.95416362791674), ]) def test_custom_1_5bond_force(pdb, prm, value): pdb = app.PDBFile(pdb) @@ -76,4 +76,4 @@ def test_custom_1_5bond_force(pdb, prm, value): pairs = nbl.pairs gbE = potential.getPotentialFunc(names=["Custom1_5BondForce"]) energy = gbE(pos, box, pairs, h.paramset) - npt.assert_almost_equal(energy, value, decimal=3) \ No newline at end of file + npt.assert_almost_equal(energy, value, decimal=3)