Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add PyG-based GAT implementation. #67

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ The input passed to this model should be a `dict` with the following keys (based
* `pos`: Tensor of coordinates for each atom, shape of `(n,3)`
* `z`: Tensor of bool labels of whether each atom is a protein atom (`False`) or ligand atom (`True`), shape of `(n,)`
* `GAT`
* `g`: DGL graph object
* `x`: Tensor of input atom (node) features, shape of `(n,feats)`
`edge_index`: Tensor giving source (first row) and destination (second row) atom indices, shape of `(2,n_bonds)`

The prediction can then be generated simply with:
```python
Expand Down
2 changes: 0 additions & 2 deletions devtools/conda-envs/mtenn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
2 changes: 0 additions & 2 deletions devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- fsspec
Expand Down
32 changes: 25 additions & 7 deletions docs/docs/basic_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,37 @@ Below, we detail a basic example of building a default Graph Attention model and

.. code-block:: python

from dgllife.utils import CanonicalAtomFeaturizer, SMILESToBigraph
from mtenn.config import GATModelConfig
import rdkit.Chem as Chem
import torch

# Build model with GAT defaults
model = GATModelConfig().build()

# Build graph from SMILES
# Build mol
smiles = "CCCC"
g = SMILESToBigraph(
add_self_loop=True,
node_featurizer=CanonicalAtomFeaturizer(),
)(smiles)
mol = Chem.MolFromSmiles(smiles)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need a convenience function to do this easily for user, easy to mess up.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added one in asapdiscovery for us to use, but since there's no one right way to featurize a molecule I didn't want to add anything opinionated in here


# Get atomic numbers and bond indices (both directions)
atomic_nums = [a.GetAtomicNum() for a in mol.GetAtoms()]
bond_idxs = [
atom_pair
for bond in mol.GetBonds()
for atom_pair in (
(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
(bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()),
)
]
# Add self bonds
bond_idxs += [(a.GetIdx(), a.GetIdx()) for a in mol.GetAtoms()]

# Encode atomic numbers as one-hot, assume max num of 100
node_feats = torch.nn.functional.one_hot(
torch.tensor(atomic_nums), num_classes=100
).to(dtype=torch.float)
# Format bonds in correct shape
edge_index = torch.tensor(bond_idxs).t()

# Make a prediction
pred, _ = model({"g": g})
pred, _ = model({"x": node_feats, "edge_index": edge_index})

2 changes: 0 additions & 2 deletions docs/requirements.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- pydantic >=1.10.8,<2.0.0a0
Expand Down
4 changes: 1 addition & 3 deletions environment-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- rdkit
- ase
4 changes: 1 addition & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
- ase
- ase
143 changes: 133 additions & 10 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class ModelType(StringEnum):
"""

GAT = "GAT"
pyg_gat = "pyg_gat"
dgl_gat = "dgl_gat"
schnet = "schnet"
e3nn = "e3nn"
visnet = "visnet"
Expand Down Expand Up @@ -403,6 +405,127 @@ def _check_grouped(values):


class GATModelConfig(ModelConfigBase):
    """
    Class for constructing a GAT ML model (PyTorch Geometric implementation).

    Default values here are based on the defaults in DGL-LifeSci, retained for
    backwards compatibility even though the implementation no longer uses DGL.
    """

    # Pinned discriminator so pydantic can tell the model config classes apart
    model_type: ModelType = Field(ModelType.GAT, const=True)

    in_channels: int = Field(
        -1,
        description=(
            "Input size. Can be left as -1 (default) to interpret based on "
            "first forward call."
        ),
    )
    hidden_channels: int = Field(32, description="Hidden embedding size.")
    num_layers: int = Field(2, description="Number of GAT layers.")
    v2: bool = Field(False, description="Use GATv2Conv layer instead of GATConv.")
    dropout: float = Field(0, description="Dropout probability.")
    heads: int = Field(4, description="Number of attention heads for each GAT layer.")
    negative_slope: float = Field(
        0.2, description="LeakyReLU angle of the negative slope."
    )

    def _build(self, mtenn_params=None):
        """
        Build an ``mtenn`` GAT ``Model`` from this config.

        :meta public:

        Parameters
        ----------
        mtenn_params : dict, optional
            Dictionary that stores the ``Readout`` objects for the individual
            predictions and for the combined prediction, and the ``Combination``
            object in the case of a multi-pose model. These are all constructed
            the same for all ``Model`` types, so we can just handle them in the
            base class. Keys in the dict will be:

            * "combination": :py:mod:`Combination <mtenn.combination>`

            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
              pose predictions

            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
              prediction (in the case of a multi-pose model)

            although the combination-related entries will be ignored because this
            is a ligand-only model.

        Returns
        -------
        mtenn.model.Model
            Model constructed from the config
        """
        from mtenn.conversion_utils.gat import GAT

        # Avoid the shared mutable-default-argument pitfall: use None as the
        # default and substitute a fresh dict per call.
        if mtenn_params is None:
            mtenn_params = {}

        model = GAT(
            in_channels=self.in_channels,
            hidden_channels=self.hidden_channels,
            num_layers=self.num_layers,
            v2=self.v2,
            dropout=self.dropout,
            heads=self.heads,
            negative_slope=self.negative_slope,
        )

        pred_readout = mtenn_params.get("pred_readout", None)
        return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)


class PyGGATModelConfig(GATModelConfig):
    """
    Alias config for the PyTorch Geometric GAT model. Identical to
    ``GATModelConfig`` except for its ``model_type`` discriminator and that it
    builds the ``PyGGAT`` conversion wrapper rather than ``GAT``.
    """

    model_type: ModelType = Field(ModelType.pyg_gat, const=True)

    def _build(self, mtenn_params=None):
        """
        Build an ``mtenn`` PyGGAT ``Model`` from this config.

        :meta public:

        Parameters
        ----------
        mtenn_params : dict, optional
            Dictionary that stores the ``Readout`` objects for the individual
            predictions and for the combined prediction, and the ``Combination``
            object in the case of a multi-pose model. These are all constructed
            the same for all ``Model`` types, so we can just handle them in the
            base class. Keys in the dict will be:

            * "combination": :py:mod:`Combination <mtenn.combination>`

            * "pred_readout": :py:mod:`Readout <mtenn.readout>` for individual
              pose predictions

            * "comb_readout": :py:mod:`Readout <mtenn.readout>` for combined
              prediction (in the case of a multi-pose model)

            although the combination-related entries will be ignored because this
            is a ligand-only model.

        Returns
        -------
        mtenn.model.Model
            Model constructed from the config
        """
        from mtenn.conversion_utils.pyg_gat import PyGGAT

        # Avoid the shared mutable-default-argument pitfall: use None as the
        # default and substitute a fresh dict per call.
        if mtenn_params is None:
            mtenn_params = {}

        model = PyGGAT(
            in_channels=self.in_channels,
            hidden_channels=self.hidden_channels,
            num_layers=self.num_layers,
            v2=self.v2,
            dropout=self.dropout,
            heads=self.heads,
            negative_slope=self.negative_slope,
        )

        pred_readout = mtenn_params.get("pred_readout", None)
        return PyGGAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)


class DGLGATModelConfig(ModelConfigBase):
"""
Class for constructing a graph attention ML model. Note that there are two methods
for defining the size of the model:
Expand Down Expand Up @@ -440,7 +563,7 @@ class GATModelConfig(ModelConfigBase):
"biases": bool,
} #: :meta private:

model_type: ModelType = Field(ModelType.GAT, const=True)
model_type: ModelType = Field(ModelType.dgl_gat, const=True)

in_feats: int = Field(
_CanonicalAtomFeaturizer().feat_size(),
Expand Down Expand Up @@ -532,7 +655,7 @@ class GATModelConfig(ModelConfigBase):
_from_num_layers = False

@root_validator(pre=False)
def massage_into_lists(cls, values) -> GATModelConfig:
def massage_into_lists(cls, values) -> DGLGATModelConfig:
"""
Validator to handle unifying all the values into the proper list forms based on
the rules described in the class docstring.
Expand Down Expand Up @@ -621,9 +744,9 @@ def _build(self, mtenn_params={}):
mtenn.model.Model
Model constructed from the config
"""
from mtenn.conversion_utils.gat import GAT
from mtenn.conversion_utils.dgl_gat import DGLGAT

model = GAT(
model = DGLGAT(
in_feats=self.in_feats,
hidden_feats=self.hidden_feats,
num_heads=self.num_heads,
Expand All @@ -638,9 +761,9 @@ def _build(self, mtenn_params={}):
)

pred_readout = mtenn_params.get("pred_readout", None)
return GAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)
return DGLGAT.get_model(model=model, pred_readout=pred_readout, fix_device=True)

def _update(self, config_updates={}) -> GATModelConfig:
def _update(self, config_updates={}) -> DGLGATModelConfig:
"""
GAT-specific implementation of updating logic. Need to handle stuff specially
to make sure that the original method of specifying parameters (either from a
Expand All @@ -656,15 +779,15 @@ def _update(self, config_updates={}) -> GATModelConfig:

Returns
-------
GATModelConfig
New ``GATModelConfig`` object
DGLGATModelConfig
New ``DGLGATModelConfig`` object
"""
orig_config = self.dict()
if self._from_num_layers or ("num_layers" in config_updates):
# If originally generated from num_layers, want to pull out the first entry
# in each list param so it can be re-broadcast with (potentially) new
# num_layers
for param_name in GATModelConfig.LIST_PARAMS.keys():
for param_name in DGLGATModelConfig.LIST_PARAMS.keys():
orig_config[param_name] = orig_config[param_name][0]

# Get new config by overwriting old stuff with any new stuff
Expand All @@ -676,7 +799,7 @@ def _update(self, config_updates={}) -> GATModelConfig:
):
new_config["activations"] = None

return GATModelConfig(**new_config)
return DGLGATModelConfig(**new_config)


class SchNetModelConfig(ModelConfigBase):
Expand Down
Loading
Loading