
vector gating + atom3d
bjing2016 committed Jul 14, 2021
1 parent 5263b49 commit 85ca918
Showing 7 changed files with 1,057 additions and 23 deletions.
100 changes: 94 additions & 6 deletions README.md
@@ -2,9 +2,11 @@

Implementation of equivariant GVP-GNNs as described in [Learning from Protein Structure with Geometric Vector Perceptrons](https://openreview.net/forum?id=1YLJDvSx6J4) by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror.

This repository serves two purposes. If you would like to use the GVP architecture for structural biology tasks, we provide building blocks for models and data pipelines. If you are specifically interested in protein design as described in the paper, we provide scripts for training and testing models.
**UPDATE:** Also includes equivariant GNNs with vector gating as described in [Equivariant Graph Neural Networks for 3D Macromolecular Structure](https://arxiv.org/abs/2106.03843) by B Jing, S Eismann, P Soni, and RO Dror.

**Note:** This repository is an implementation in PyTorch Geometric emphasizing usability and flexibility. The original code for the paper, in TensorFlow, can be found [here](https://github.com/drorlab/gvp). We thank Pratham Soni for his contributions to the implementation in PyTorch.
Scripts for training/testing/sampling on protein design and training/testing on all [ATOM3D](https://arxiv.org/abs/2012.04035) tasks are provided.

**Note:** This implementation is in PyTorch Geometric. The original TensorFlow code, which is not maintained, can be found [here](https://github.com/drorlab/gvp).

<p align="center"><img src="schematic.png" width="500"></p>

@@ -20,15 +22,17 @@ This repository serves two purposes. If you would like to use the GVP architectu
* tqdm==4.38.0
* numpy==1.19.4
* sklearn==0.24.1
* atom3d==0.2.1

While we have not tested with other versions, any reasonably recent versions of these requirements should work.

## General usage

We provide classes in three modules:
* `gvp`: core GVP modules and GVP-GNN layers
* `gvp.data`: data pipeline functionality for both general use and protein design
* `gvp.models`: implementations of MQA and CPD models as described in the paper
* `gvp.data`: data pipelines for both general use and protein design
* `gvp.models`: implementations of MQA and CPD models
* `gvp.atom3d`: models and data pipelines for ATOM3D

The core modules in `gvp` are meant to be as general as possible, but you will likely have to modify `gvp.data` and `gvp.models` for your specific application, with the existing classes serving as examples.

@@ -52,6 +56,11 @@ in_dims = scalars_in, vectors_in
out_dims = scalars_out, vectors_out
gvp_ = gvp.GVP(in_dims, out_dims)
```
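A constructed module can be sanity-checked on random features. This is a minimal sketch; it assumes `gvp.randn(n, dims)` returns a random `(s, V)` tuple, as it is used later in this README.
```
x = gvp.randn(5, in_dims)   # tuple (s, V) with shapes (5, scalars_in), (5, vectors_in, 3)
out = gvp_(x)               # tuple (s, V) with shapes (5, scalars_out), (5, vectors_out, 3)
```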
To use vector gating, pass in `vector_gate=True` and the appropriate activations.
```
gvp_ = gvp.GVP(in_dims, out_dims,
               activations=(F.relu, None), vector_gate=True)
```
The classes `gvp.Dropout` and `gvp.LayerNorm` implement vector-channel dropout and layer norm, while using normal dropout and layer norm for scalar channels. Both expect inputs and return outputs of form `(s, V)`, but will also behave like their scalar-valued counterparts if passed a single tensor.
```
dropout = gvp.Dropout(drop_rate=0.1)
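# --- a minimal usage sketch, not part of the original snippet ---
# LayerNorm is constructed from a (n_scalar, n_vector) dims tuple; both modules
# accept and return (s, V) tuples (gvp.randn here is assumed, as used above)
layernorm = gvp.LayerNorm(out_dims)
x = gvp.randn(5, out_dims)
x = layernorm(dropout(x))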
@@ -86,7 +95,7 @@ edge_index = torch.randint(0, 5, (2, 10), device=device)
conv = gvp.GVPConv(in_dims, out_dims, edge_dims)
out = conv(nodes, edge_index, edges)
```
The class GVPConvLayer is a `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings are not changed.
The class `GVPConvLayer` is an `nn.Module` that forms messages using a `GVPConv` and updates the node embeddings as described in the paper. Because the updates are residual, the dimensionality of the embeddings is not changed.
```
layer = gvp.GVPConvLayer(node_dims, edge_dims)
nodes = layer(nodes, edge_index, edges)
@@ -97,6 +106,8 @@ nodes_static = gvp.randn(n=5, in_dims)
layer = gvp.GVPConvLayer(node_dims, edge_dims, autoregressive=True)
nodes = layer(nodes, edge_index, edges, autoregressive_x=nodes_static)
```
Both `GVPConv` and `GVPConvLayer` accept arguments `activations` and `vector_gate` to use vector gating.
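For example, a vector-gated graph layer can be constructed as follows (a minimal sketch reusing the node/edge dimensions and embeddings from above):
```
layer = gvp.GVPConvLayer(node_dims, edge_dims,
                         activations=(F.relu, None), vector_gate=True)
nodes = layer(nodes, edge_index, edges)
```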

### Loading data

The class `gvp.data.ProteinGraphDataset` transforms protein backbone structures into featurized graphs. Following [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design), we use a JSON/dictionary format to specify backbone structures:
@@ -204,6 +215,76 @@ sample = model.sample(nodes, protein.edge_index, # shape = (n_samples, n_nodes)
```
The output will be an int tensor, with mappings corresponding to those used when training the model.
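To map the sampled indices back to amino acid letters, the dataset's residue alphabet can be inverted. This is a sketch only; the `dataset` variable and the `num_to_letter` attribute are assumptions---check `gvp.data.ProteinGraphDataset` for the actual mapping.
```
# assumes dataset is a gvp.data.ProteinGraphDataset exposing num_to_letter
# (verify the attribute name before relying on this)
seqs = [''.join(dataset.num_to_letter[int(i)] for i in seq) for seq in sample]
```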

## ATOM3D
We provide models and data loaders for all ATOM3D tasks in `gvp.atom3d`, as well as a training and testing script in `run_atom3d.py`. The script also supports loading pretrained weights for transfer learning experiments.

### Models / data loaders
The GVP-GNNs for ATOM3D are supplied in `gvp.atom3d` and are named after each task: `gvp.atom3d.MSPModel`, `gvp.atom3d.PPIModel`, etc. All of them extend the base class `gvp.atom3d.BaseModel`. These classes take no arguments at initialization; when called, they take in a `torch_geometric.data.Batch` representation of a batch of structures and return an output corresponding to the task. Details vary by task---see the docstrings.
```
psr_model = gvp.atom3d.PSRModel()
```
`gvp.atom3d` also includes data loaders to produce `torch_geometric.data.Batch` objects from an underlying `atom3d.datasets.LMDBDataset`. For all tasks except PPI and RES, these are in the form of callable transform objects---`gvp.atom3d.SMPTransform`, `gvp.atom3d.RSRTransform`, etc.---which should be passed into the constructor of an `atom3d.datasets.LMDBDataset`:
```
psr_dataset = atom3d.datasets.LMDBDataset(path_to_dataset,
                    transform=gvp.atom3d.PSRTransform())
```
On the other hand, `gvp.atom3d.PPIDataset` and `gvp.atom3d.RESDataset` are wrappers around, and take the place of, the `atom3d.datasets.LMDBDataset`:
```
ppi_dataset = gvp.atom3d.PPIDataset(path_to_dataset)
res_dataset = gvp.atom3d.RESDataset(path_to_dataset, path_to_split) # see docstring
```
All datasets must then be wrapped in a `torch_geometric.data.DataLoader`:
```
psr_dataloader = torch_geometric.data.DataLoader(psr_dataset, batch_size=batch_size)
```
The dataloaders can be directly iterated over to yield `torch_geometric.data.Batch` objects, which can then be passed into the models.
```
for batch in psr_dataloader:
    pred = psr_model(batch)  # pred.shape = (batch_size,)
```
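If running on a GPU, the model and each batch can be moved to the device first. This is a minimal sketch; `torch_geometric.data.Batch` objects support `.to()` like any other PyTorch Geometric data object.
```
device = 'cuda' if torch.cuda.is_available() else 'cpu'
psr_model = psr_model.to(device)
for batch in psr_dataloader:
    pred = psr_model(batch.to(device))
```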

### Training / testing

To run training / testing on ATOM3D, download the datasets as described [here](https://www.atom3d.ai/). Modify the function `get_datasets` in `run_atom3d.py` with the paths to the datasets. Then run:
```
$ python run_atom3d.py -h
usage: run_atom3d.py [-h] [--num-workers N] [--smp-idx IDX]
                     [--lba-split SPLIT] [--batch SIZE] [--train-time MINUTES]
                     [--val-time MINUTES] [--epochs N] [--test PATH]
                     [--lr RATE] [--load PATH]
                     TASK

positional arguments:
  TASK                  {PSR, RSR, PPI, RES, MSP, SMP, LBA, LEP}

optional arguments:
  -h, --help            show this help message and exit
  --num-workers N       number of threads for loading data, default=4
  --smp-idx IDX         label index for SMP, in range 0-19
  --lba-split SPLIT     identity cutoff for LBA, 30 (default) or 60
  --batch SIZE          batch size, default=8
  --train-time MINUTES  maximum time between evaluations on valset,
                        default=120 minutes
  --val-time MINUTES    maximum time per evaluation on valset, default=20
                        minutes
  --epochs N            training epochs, default=50
  --test PATH           evaluate a trained model
  --lr RATE             learning rate
  --load PATH           initialize first 2 GNN layers with pretrained weights
```
For example:
```
# train a model
python run_atom3d.py PSR
# train a model with pretrained weights
python run_atom3d.py PSR --load PATH
# evaluate a model
python run_atom3d.py PSR --test PATH
```

## Acknowledgements
Portions of the input data pipeline were adapted from [Ingraham, et al, NeurIPS 2019](https://github.com/jingraham/neurips19-graph-protein-design). We thank Pratham Soni for portions of the implementation in PyTorch.

@@ -217,4 +298,11 @@ Portions of the input data pipeline were adapted from [Ingraham, et al, NeurIPS
year={2021},
url={https://openreview.net/forum?id=1YLJDvSx6J4}
}
```
@article{jing2021equivariant,
title={Equivariant Graph Neural Networks for 3D Macromolecular Structure},
author={Jing, Bowen and Eismann, Stephan and Soni, Pratham N and Dror, Ron O},
journal={arXiv preprint arXiv:2106.03843},
year={2021}
}
```
54 changes: 39 additions & 15 deletions gvp/__init__.py
@@ -1,4 +1,4 @@
import torch
import torch, functools
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
@@ -85,18 +85,22 @@ class GVP(nn.Module):
:param out_dims: tuple (n_scalar, n_vector)
:param h_dim: intermediate number of vector channels, optional
:param activations: tuple of functions (scalar_act, vector_act)
:param vector_gate: whether to use vector gating.
(vector_act will be used as sigma^+ in vector gating if `True`)
'''
def __init__(self, in_dims, out_dims, h_dim=None,
activations=(F.relu, torch.sigmoid)):
activations=(F.relu, torch.sigmoid), vector_gate=False):
super(GVP, self).__init__()
self.si, self.vi = in_dims
self.so, self.vo = out_dims
self.vector_gate = vector_gate
if self.vi:
self.h_dim = h_dim or max(self.vi, self.vo)
self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
self.ws = nn.Linear(self.h_dim + self.si, self.so)
if self.vo:
self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
else:
self.ws = nn.Linear(self.si, self.so)

@@ -119,7 +123,13 @@ def forward(self, x):
if self.vo:
v = self.wv(vh)
v = torch.transpose(v, -1, -2)
if self.vector_act:
if self.vector_gate:
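# vector gating: each output vector channel is scaled by a sigmoid gate
# computed from the scalar output s (optionally passed through vector_act)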
if self.vector_act:
gate = self.wsv(self.vector_act(s))
else:
gate = self.wsv(s)
v = v * torch.sigmoid(gate).unsqueeze(-1)
elif self.vector_act:
v = v * self.vector_act(
_norm_no_nan(v, axis=-1, keepdims=True))
else:
@@ -214,28 +224,35 @@ class GVPConv(MessagePassing):
:param n_layers: number of GVPs in the message function
:param module_list: preconstructed message function, overrides n_layers
:param aggr: should be "add" if some incoming edges are masked, as in
a masked autoregressive decoder architecture
a masked autoregressive decoder architecture, otherwise "mean"
:param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
:param vector_gate: whether to use vector gating.
(vector_act will be used as sigma^+ in vector gating if `True`)
'''
def __init__(self, in_dims, out_dims, edge_dims,
n_layers=3, module_list=None, aggr="mean"):
n_layers=3, module_list=None, aggr="mean",
activations=(F.relu, torch.sigmoid), vector_gate=False):
super(GVPConv, self).__init__(aggr=aggr)
self.si, self.vi = in_dims
self.so, self.vo = out_dims
self.se, self.ve = edge_dims

GVP_ = functools.partial(GVP,
activations=activations, vector_gate=vector_gate)

module_list = module_list or []
if not module_list:
if n_layers == 1:
module_list.append(
GVP((2*self.si + self.se, 2*self.vi + self.ve),
GVP_((2*self.si + self.se, 2*self.vi + self.ve),
(self.so, self.vo), activations=(None, None)))
else:
module_list.append(
GVP((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
)
for i in range(n_layers - 2):
module_list.append(GVP(out_dims, out_dims))
module_list.append(GVP(out_dims, out_dims,
module_list.append(GVP_(out_dims, out_dims))
module_list.append(GVP_(out_dims, out_dims,
activations=(None, None)))
self.message_func = nn.Sequential(*module_list)

@@ -276,26 +293,33 @@ class GVPConvLayer(nn.Module):
:param autoregressive: if `True`, this `GVPConvLayer` will be used
with a different set of input node embeddings for messages
where src >= dst
:param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
:param vector_gate: whether to use vector gating.
(vector_act will be used as sigma^+ in vector gating if `True`)
'''
def __init__(self, node_dims, edge_dims,
n_message=3, n_feedforward=2, drop_rate=.1,
autoregressive=False):
autoregressive=False,
activations=(F.relu, torch.sigmoid), vector_gate=False):

super(GVPConvLayer, self).__init__()
self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
aggr="add" if autoregressive else "mean")
aggr="add" if autoregressive else "mean",
activations=activations, vector_gate=vector_gate)
GVP_ = functools.partial(GVP,
activations=activations, vector_gate=vector_gate)
self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

ff_func = []
if n_feedforward == 1:
ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
else:
hid_dims = 4*node_dims[0], 2*node_dims[1]
ff_func.append(GVP(node_dims, hid_dims))
ff_func.append(GVP_(node_dims, hid_dims))
for i in range(n_feedforward-2):
ff_func.append(GVP(hid_dims, hid_dims))
ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
ff_func.append(GVP_(hid_dims, hid_dims))
ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
self.ff_func = nn.Sequential(*ff_func)

def forward(self, x, edge_index, edge_attr,

