Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lsj2408 committed Oct 4, 2022
1 parent 137fc4e commit f761db0
Show file tree
Hide file tree
Showing 1,179 changed files with 196,428 additions and 0 deletions.
137 changes: 137 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# JetBrains PyCharm IDE
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# macOS dir files
.DS_Store

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Checkpoints
checkpoints

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Generated files
/fairseq/temporal_convolution_tbc
/fairseq/modules/*_layer/*_forward.cu
/fairseq/modules/*_layer/*_backward.cu
/fairseq/version.py

# data
data-bin/
datasets/
logs/
# reranking
/examples/reranking/rerank_data

# Cython-generated C++ source files
/fairseq/data/data_utils_fast.cpp
/fairseq/data/token_block_utils_fast.cpp

# VSCODE
.vscode/ftp-sync.json
.vscode/settings.json

# Experimental Folder
experimental/*

# Weights and Biases logs
wandb/
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[submodule "fairseq/model_parallel/megatron"]
path = fairseq/model_parallel/megatron
url = https://github.com/ngoyal2707/Megatron-LM
branch = fairseq
Empty file modified LICENSE
100644 → 100755
Empty file.
135 changes: 135 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# One Transformer Can Understand Both 2D & 3D Molecular Data

This repository is the official implementation of “One Transformer Can Understand Both 2D & 3D Molecular Data”, based on the official implementation of [Graphormer](https://github.com/microsoft/Graphormer) and [Fairseq](https://github.com/facebookresearch/fairseq) in [PyTorch](https://github.com/pytorch/pytorch).

## Overview

![arch](docs/arch.jpg)

Transformer-M is a versatile and effective molecular model that can take molecular data of 2D or 3D formats as input and generate meaningful semantic representations. Using the standard Transformer as the backbone architecture, Transformer-M develops two separated channels to encode 2D and 3D structural information and incorporate them with the atom features in the network modules. When the input data is in a particular format, the corresponding channel will be activated, and the other will be disabled. Empirical results show that our Transformer-M can achieve strong performance on 2D and 3D tasks simultaneously, which is the first step toward general-purpose molecular models in chemistry.

## Results on PCQM4Mv2, OGB-Large Scale Challenge

![](docs/Table1.png)
🚀**Note:** **PCQM4Mv2** is also the benchmark dataset of the graph-level track in the **2nd OGB-LSC** at [**NeurIPS 2022 competition track**](https://ogb.stanford.edu/neurips2022/). As non-participants, we open source all the codes and model weights, and sincerely welcome participants to use our model. Looking forward to your feedback!

## Installation

- Clone this repository

```shell
git clone https://github.com/lsj2408/Transformer-M.git
```

- Install the dependencies (Using anaconda, tested with CUDA version 11.0)

```shell
cd ./Transformer-M
conda env create -f requirement.yaml
conda activate Transformer-M
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch_geometric==1.6.3
pip install torch_scatter==2.0.7
pip install torch_sparse==0.6.9
pip install azureml-defaults
pip install rdkit-pypi cython
python setup.py build_ext --inplace
python setup_cython.py build_ext --inplace
pip install -e .
pip install --upgrade protobuf==3.20.1
pip install --upgrade tensorboard==2.9.1
pip install --upgrade tensorboardX==2.5.1
```

## Checkpoints

| Model | File Size | Update Date | Valid MAE on PCQM4Mv2 | Download Link |
| ----- | --------- | ------------ | --------------------- | -------------------------------------------------------- |
| L12 | 189MB | Oct 04, 2022 | 0.0785 | https://1drv.ms/u/s!AgZyC7AzHtDBdWUZttg6N2TsOxw?e=sUOhox |
| L18 | 270MB | Oct 04, 2022 | 0.0772 | https://1drv.ms/u/s!AgZyC7AzHtDBdrY59-_mP38jsCg?e=URoyUK |

```shell
# create paths to checkpoints for evaluation

# download the above model weights (L12.pt, L18.pt) to ./
cd Transformer-M
mkdir -p logs/L12
mkdir -p logs/L18
mv L12.pt logs/L12/
mv L18.pt logs/L18/
```

## Datasets

- Preprocessed data: [download link](https://1drv.ms/u/s!AgZyC7AzHtDBeIDqE61u1ZEMv_8?e=3g428e)

```shell
# create paths to datasets for evaluation/training

# download the above compressed datasets (pcqm4mv2-pos.zip) to ./
unzip pcqm4mv2-pos.zip -d ./datasets
mv ./datasets ./Transformer-M
```

- You can also directly execute the evaluation/training code to process data from scratch.

## Evaluation

```shell
export data_path='./datasets/pcq-pos' # path to data
export save_path='./logs/{folder_to_checkpoints}' # path to checkpoints, e.g., ./logs/L12

export layers=12 # set layers=18 for 18-layer model
export hidden_size=768 # dimension of hidden layers
export ffn_size=768 # dimension of feed-forward layers
export num_head=32 # number of attention heads
export num_3d_bias_kernel=128 # number of Gaussian Basis kernels
export batch_size=256 # batch size for a single gpu
export dataset_name="PCQM4M-LSC-V2-3D"
export add_3d="true"
bash evaluate.sh
```

## Training

```shell
# L12. Valid MAE: 0.0785
export data_path='./datasets/pcq-pos' # path to data
export save_path='./logs/' # path to logs

export lr=2e-4 # peak learning rate
export warmup_steps=150000 # warmup steps
export total_steps=1500000 # total steps
export layers=12 # set layers=18 for 18-layer model
export hidden_size=768 # dimension of hidden layers
export ffn_size=768 # dimension of feed-forward layers
export num_head=32 # number of attention heads
export batch_size=32 # batch size for a single gpu
export dropout=0.0
export act_dropout=0.1
export attn_dropout=0.1
export weight_decay=0.0
export droppath_prob=0.1 # probability of stochastic depth
export noise_scale=0.2 # noise scale
export mode_prob="0.2,0.2,0.6" # mode distribution for {2D+3D, 2D, 3D}
export dataset_name="PCQM4M-LSC-V2-3D"
export add_3d="true"
export num_3d_bias_kernel=128 # number of Gaussian Basis kernels
bash train.sh
```

Our model is trained on 4 NVIDIA Tesla A100 GPUs (40GB). The time cost for an epoch is around 10 minutes.

## References

TBA

## Contact

Shengjie Luo ([email protected])

Sincerely appreciate your suggestions on our work!

## License

This project is licensed under the terms of the MIT license. See [LICENSE](https://github.com/lsj2408/Transformer-M/blob/main/LICENSE) for additional details.
1 change: 1 addition & 0 deletions Transformer-M/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import criterions
1 change: 1 addition & 0 deletions Transformer-M/criterions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import graph_prediction
105 changes: 105 additions & 0 deletions Transformer-M/criterions/graph_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from dataclasses import dataclass
import math
from omegaconf import II

import torch
import torch.nn as nn
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
import os


@dataclass
class GraphPredictionConfig(FairseqDataclass):
tpu: bool = II("common.tpu")


@register_criterion("graph_prediction", dataclass=GraphPredictionConfig)
class GraphPredictionLoss(FairseqCriterion):
"""
Implementation for the loss used in masked graph model (MGM) training.
"""

def __init__(self, cfg: GraphPredictionConfig, task):
super().__init__(task)
self.tpu = cfg.tpu
self.noise_scale = task.cfg.noise_scale

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
sample_size = sample["nsamples"]

with torch.no_grad():
natoms = sample["net_input"]["batched_data"]['x'].shape[1]

# add gaussian noise
ori_pos = sample['net_input']['batched_data']['pos']
noise = torch.randn(ori_pos.shape).to(ori_pos) * self.noise_scale
noise_mask = (ori_pos == 0.0).all(dim=-1, keepdim=True)
noise = noise.masked_fill_(noise_mask, 0.0)
sample['net_input']['batched_data']['pos'] = ori_pos + noise


model_output = model(**sample["net_input"])
logits, node_output = model_output[0], model_output[1]
logits = logits[:,0,:]
targets = model.get_targets(sample, [logits])

loss = nn.L1Loss(reduction='sum')(logits, targets)

if node_output is not None:
node_mask = (node_output == 0.0).all(dim=-1).all(dim=-1)[:, None, None] + noise_mask
node_output = node_output.masked_fill_(node_mask, 0.0)

node_output_loss = (1.0 - nn.CosineSimilarity(dim=-1)(node_output.to(torch.float32), noise.masked_fill_(node_mask, 0.0).to(torch.float32)))
node_output_loss = node_output_loss.masked_fill_(node_mask.squeeze(-1), 0.0).sum(dim=-1).to(torch.float16)

tgt_count = (~node_mask).squeeze(-1).sum(dim=-1).to(node_output_loss)
tgt_count = tgt_count.masked_fill_(tgt_count == 0.0, 1.0)
node_output_loss = (node_output_loss / tgt_count).sum() * 1
else:
node_output_loss = (noise - noise).sum()

logging_output = {
"loss": loss.data,
"node_output_loss": node_output_loss.data,
"total_loss": loss.data + node_output_loss.data,
"sample_size": sample_size,
"nsentences": sample_size,
"ntokens": natoms,
}
return loss + node_output_loss, sample_size, logging_output

@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
node_output_loss_sum = sum(log.get("node_output_loss", 0) for log in logging_outputs)
total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)

metrics.log_scalar(
"loss", loss_sum / sample_size, sample_size, round=6
)
metrics.log_scalar(
"node_output_loss", node_output_loss_sum / sample_size, sample_size, round=6
)
metrics.log_scalar(
"total_loss", total_loss_sum / sample_size, sample_size, round=6
)

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
Loading

0 comments on commit f761db0

Please sign in to comment.