-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add sGNN generator fixed a few problems in ADMPPmeGenerator * remove debugging codes
- Loading branch information
Showing
38 changed files
with
737 additions
and
208 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .classical import * | ||
from .admp import * | ||
from .admp import * | ||
from .ml import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from ..api.topology import DMFFTopology | ||
from ..api.paramset import ParamSet | ||
from ..api.hamiltonian import _DMFFGenerators | ||
from ..utils import DMFFException, isinstance_jnp | ||
from ..utils import jit_condition | ||
import numpy as np | ||
import jax | ||
import jax.numpy as jnp | ||
import openmm.app as app | ||
import openmm.unit as unit | ||
import pickle | ||
|
||
from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb | ||
from ..sgnn.gnn import MolGNNForce, prm_transform_f2i | ||
|
||
|
||
class SGNNGenerator: | ||
def __init__(self, ffinfo: dict, paramset: ParamSet): | ||
|
||
self.name = "SGNNForce" | ||
self.ffinfo = ffinfo | ||
paramset.addField(self.name) | ||
self.key_type = None | ||
|
||
self.file = self.ffinfo["Forces"][self.name]["meta"]["file"] | ||
self.nn = int(self.ffinfo["Forces"][self.name]["meta"]["nn"]) | ||
self.pdb = self.ffinfo["Forces"][self.name]["meta"]["pdb"] | ||
|
||
# load ML potential parameters | ||
with open(self.file, 'rb') as ifile: | ||
params = pickle.load(ifile) | ||
|
||
# convert to jnp array | ||
for k in params: | ||
params[k] = jnp.array(params[k]) | ||
# set mask to all true | ||
paramset.addParameter(params[k], k, field=self.name, mask=jnp.ones(params[k].shape)) | ||
|
||
# mask = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), params) | ||
# paramset.addParameter(params, "params", field=self.name, mask=mask) | ||
|
||
|
||
def getName(self) -> str: | ||
return self.name | ||
|
||
def overwrite(self, paramset): | ||
# do not use xml to handle ML potentials | ||
# for ML potentials, xml only documents param file path | ||
# so for ML potentials, overwrite function overwrites the file directly | ||
with open(self.file, 'wb') as ofile: | ||
pickle.dump(paramset[self.name], ofile) | ||
return | ||
|
||
def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs): | ||
self.G = from_pdb(self.pdb) | ||
n_atoms = topdata.getNumAtoms() | ||
self.model = MolGNNForce(self.G, nn=self.nn) | ||
n_layers = self.model.n_layers | ||
def potential_fn(positions, box, pairs, params): | ||
# convert unit to angstrom | ||
positions = positions * 10 | ||
box = box * 10 | ||
prms = prm_transform_f2i(params[self.name], n_layers) | ||
return self.model.get_energy(positions, box, prms) | ||
|
||
self._jaxPotential = potential_fn | ||
return potential_fn | ||
|
||
def getJaxPotential(self): | ||
return self._jaxPotential | ||
|
||
|
||
_DMFFGenerators["SGNNForce"] = SGNNGenerator | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
test_backend/model1.pickle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
<ForceField> | ||
<AtomTypes> | ||
<Type class="CT" element="C" mass="12.0107" name="1" /> | ||
<Type class="HC" element="H" mass="1.00784" name="2" /> | ||
<Type class="OS" element="O" mass="15.999" name="3" /> | ||
<Type class="CT" element="C" mass="12.0107" name="4" /> | ||
<Type class="HC" element="H" mass="1.00784" name="5" /> | ||
</AtomTypes> | ||
<Residues> | ||
<Residue name="TER"> | ||
<Atom name="C00" type="1" /> | ||
<Atom name="H01" type="2" /> | ||
<Atom name="H02" type="2" /> | ||
<Atom name="O03" type="3" /> | ||
<Atom name="C04" type="4" /> | ||
<Atom name="H05" type="5" /> | ||
<Atom name="H06" type="5" /> | ||
<Atom name="H07" type="5" /> | ||
<Bond from="0" to="1" /> | ||
<Bond from="0" to="2" /> | ||
<Bond from="0" to="3" /> | ||
<Bond from="3" to="4" /> | ||
<Bond from="4" to="5" /> | ||
<Bond from="4" to="6" /> | ||
<Bond from="4" to="7" /> | ||
<ExternalBond atomName="C00" /> | ||
</Residue> | ||
<Residue name="INT"> | ||
<Atom name="C00" type="1" /> | ||
<Atom name="H01" type="2" /> | ||
<Atom name="H02" type="2" /> | ||
<Atom name="O03" type="3" /> | ||
<Atom name="C04" type="1" /> | ||
<Atom name="H05" type="2" /> | ||
<Atom name="H06" type="2" /> | ||
<Bond from="0" to="1" /> | ||
<Bond from="0" to="2" /> | ||
<Bond from="0" to="3" /> | ||
<Bond from="3" to="4" /> | ||
<Bond from="4" to="5" /> | ||
<Bond from="4" to="6" /> | ||
<ExternalBond atomName="C00" /> | ||
<ExternalBond atomName="C04" /> | ||
</Residue> | ||
</Residues> | ||
<SGNNForce file="model1.pickle" pdb="peg4.pdb" nn="1"/> | ||
</ForceField> | ||
|
Oops, something went wrong.