
Commit

molecule generation code ready
Wengong Jin committed Apr 21, 2021
1 parent 0bf7b08 commit 6d37153
Showing 10 changed files with 326 additions and 297 deletions.
51 changes: 38 additions & 13 deletions README.md
@@ -6,41 +6,66 @@ Our paper is at https://arxiv.org/pdf/2002.03230.pdf
First install the dependencies via conda:
* PyTorch >= 1.0.0
* networkx
* RDKit
* RDKit >= 2019.03
* numpy
* Python >= 3.6

And then run `pip install .`

## Molecule Generation
The molecule generation code is in the `generation/` folder.
## Data Format
* For graph generation, each line of a training file is a SMILES string of a molecule.
* For graph translation, each line of a training file is a pair of molecules (molA, molB) that are similar to each other but molB has better chemical properties. Please see `data/qed/train_pairs.txt`. The test file is a list of molecules to be optimized. Please see `data/qed/test.txt`. A quick format check is sketched below.
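
The following is a minimal sketch (not part of the repository) for sanity-checking files in these two formats with RDKit. It assumes translation pairs are whitespace-separated, as in `data/qed/train_pairs.txt`.

```
# Hedged sketch: check that generation files hold one valid SMILES per line and
# that translation files hold two whitespace-separated SMILES (molA molB) per line.
from rdkit import Chem

def check_generation_file(path):
    with open(path) as f:
        for i, line in enumerate(f, 1):
            if Chem.MolFromSmiles(line.strip()) is None:
                print(f"{path}:{i} invalid SMILES: {line.strip()}")

def check_translation_file(path):
    with open(path) as f:
        for i, line in enumerate(f, 1):
            tokens = line.split()
            if len(tokens) != 2 or any(Chem.MolFromSmiles(s) is None for s in tokens):
                print(f"{path}:{i} expected two valid SMILES: {line.strip()}")

check_generation_file("data/qed/mols.txt")
check_translation_file("data/qed/train_pairs.txt")
```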

## Graph translation Data Format
* The training file should contain pairs of molecules (molA, molB) that are similar to each other but molB has better chemical properties. Please see `data/qed/train_pairs.txt`.
* The test file is a list of molecules to be optimized. Please see `data/qed/test.txt`.
## Graph generation training procedure
1. Extract substructure vocabulary from a given set of molecules:
```
python get_vocab.py --ncpu 16 < data/qed/mols.txt > vocab.txt
```
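
Each line of `vocab.txt` holds two whitespace-separated SMILES tokens (a substructure label and its attachment form), matching the `print(x, y)` at the end of `get_vocab.py` shown later in this commit. Note that the later steps read the vocabulary from `data/qed/vocab.txt`, so move or point to your generated file accordingly. A minimal sketch for inspecting the result:

```
# Hedged sketch: inspect the extracted vocabulary. Assumes two whitespace-separated
# SMILES tokens per line, matching the `print(x, y)` output of get_vocab.py.
with open("vocab.txt") as f:
    vocab = [tuple(line.split()) for line in f if line.strip()]

print(f"{len(vocab)} substructure entries")
for entry in vocab[:5]:
    print(entry)
```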

2. Preprocess training data:
```
python preprocess.py --train data/qed/mols.txt --vocab data/qed/vocab.txt --ncpu 16 --mode single
mkdir train_processed
mv tensor* train_processed/
```

3. Train the graph generation model:
```
mkdir -p ckpt/generation
python train_generator.py --train train_processed/ --vocab data/qed/vocab.txt --save_dir ckpt/generation
```

4. Sample molecules from a model checkpoint:
```
python generate.py --vocab data/qed/vocab.txt --model ckpt/generation/model.5 --nsample 1000
```
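
`generate.py` prints one sampled SMILES per line to stdout (see its diff below), so the output can be redirected, e.g. by appending `> samples.txt` to the command above. A minimal sketch, assuming that layout, for measuring validity and uniqueness with RDKit:

```
# Hedged sketch: validity/uniqueness of sampled molecules, assuming generate.py's
# stdout was redirected to samples.txt (one SMILES per line).
from rdkit import Chem

with open("samples.txt") as f:
    samples = [line.strip() for line in f if line.strip()]

mols = [Chem.MolFromSmiles(s) for s in samples]
canonical = [Chem.MolToSmiles(m) for m in mols if m is not None]

print(f"validity:   {len(canonical) / max(len(samples), 1):.3f}")
print(f"uniqueness: {len(set(canonical)) / max(len(canonical), 1):.3f}")
```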

## Graph translation training procedure
1. Extract substructure vocabulary from a given set of molecules:
```
python get_vocab.py < data/qed/mols.txt > vocab.txt
python get_vocab.py --ncpu 16 < data/qed/mols.txt > vocab.txt
```
Please replace `data/qed/mols.txt` with your molecules data file.
Please replace `data/qed/mols.txt` with your own molecule file.

2. Preprocess training data:
```
python preprocess.py --train data/qed/train_pairs.txt --vocab data/qed/vocab.txt --ncpu 16 < data/qed/train_pairs.txt
python preprocess.py --train data/qed/train_pairs.txt --vocab data/qed/vocab.txt --ncpu 16
mkdir train_processed
mv tensor* train_processed/
```
Please replace `--train` and `--vocab` with your own training and vocabulary files.

3. Train the model:
```
mkdir models/
python gnn_train.py --train train_processed/ --vocab data/qed/vocab.txt --save_dir models/
mkdir -p ckpt/translation
python train_translator.py --train train_processed/ --vocab data/qed/vocab.txt --save_dir ckpt/translation
```

4. Make predictions on your lead compounds (any model checkpoint can be used; here we use model.5 for illustration):
```
python decode.py --test data/qed/valid.txt --vocab data/qed/vocab.txt --model models/model.5 --num_decode 20 > results.csv
python translate.py --test data/qed/valid.txt --vocab data/qed/vocab.txt --model ckpt/translation/model.5 --num_decode 20 > results.csv
```
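
The exact layout of `results.csv` is assumed here: each line pairs the input SMILES with one decoded SMILES, whitespace-separated; adjust the parsing if your version writes a different format. A minimal sketch for checking how often translation improves QED:

```
# Hedged sketch: fraction of translations that improve QED. Assumes each line of
# results.csv holds an input SMILES followed by a decoded SMILES, whitespace-separated.
from rdkit import Chem
from rdkit.Chem import QED

improved = total = 0
with open("results.csv") as f:
    for line in f:
        tokens = line.split()
        if len(tokens) < 2:
            continue
        src, dst = Chem.MolFromSmiles(tokens[0]), Chem.MolFromSmiles(tokens[1])
        if src is None or dst is None:
            continue
        total += 1
        improved += QED.qed(dst) > QED.qed(src)

print(f"{improved}/{total} translations improved QED")
```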

## Polymer generation
The polymer generation code is in the `polymer/` folder. It is similar to `train_generator.py`, but its substructures are tailored for polymers.
For generating regular drug-like molecules, we recommend using `train_generator.py` in the root directory.

64 changes: 0 additions & 64 deletions cond_decode.py

This file was deleted.

71 changes: 0 additions & 71 deletions decode.py

This file was deleted.

88 changes: 20 additions & 68 deletions generate.py
@@ -7,6 +7,7 @@
import math, random, sys
import numpy as np
import argparse
from tqdm import tqdm

from hgraph import *
import rdkit
@@ -15,89 +16,40 @@
lg.setLevel(rdkit.RDLogger.CRITICAL)

parser = argparse.ArgumentParser()
parser.add_argument('--train', required=True)
parser.add_argument('--vocab', required=True)
parser.add_argument('--atom_vocab', default=common_atom_vocab)
parser.add_argument('--save_dir', required=True)
parser.add_argument('--load_epoch', type=int, default=-1)
parser.add_argument('--model', required=True)

parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--nsample', type=int, default=10000)

parser.add_argument('--rnn_type', type=str, default='LSTM')
parser.add_argument('--hidden_size', type=int, default=260)
parser.add_argument('--embed_size', type=int, default=260)
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--latent_size', type=int, default=24)
parser.add_argument('--depthT', type=int, default=20)
parser.add_argument('--depthG', type=int, default=20)
parser.add_argument('--hidden_size', type=int, default=250)
parser.add_argument('--embed_size', type=int, default=250)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--latent_size', type=int, default=32)
parser.add_argument('--depthT', type=int, default=15)
parser.add_argument('--depthG', type=int, default=15)
parser.add_argument('--diterT', type=int, default=1)
parser.add_argument('--diterG', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.0)

parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--clip_norm', type=float, default=20.0)
parser.add_argument('--beta', type=float, default=0.3)

parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--anneal_rate', type=float, default=0.9)
parser.add_argument('--print_iter', type=int, default=50)
parser.add_argument('--save_iter', type=int, default=-1)

args = parser.parse_args()
print(args)

vocab = [x.strip("\r\n ").split() for x in open(args.vocab)]
args.vocab = PairVocab(vocab)

model = HierVAE(args).cuda()

for param in model.parameters():
if param.dim() == 1:
nn.init.constant_(param, 0)
else:
nn.init.xavier_normal_(param)

if args.load_epoch >= 0:
model.load_state_dict(torch.load(args.save_dir + "/model." + str(args.load_epoch)))

print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate)

param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))

total_step = 0
beta = args.beta
meters = np.zeros(6)

for epoch in range(args.load_epoch + 1, args.epoch):
dataset = DataFolder(args.train, args.batch_size)

for batch in dataset:
total_step += 1
model.zero_grad()
loss, kl_div, wacc, iacc, tacc, sacc = model(*batch, beta=beta)

loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
optimizer.step()
model.load_state_dict(torch.load(args.model))
model.eval()

meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100])
torch.manual_seed(args.seed)
random.seed(args.seed)

if total_step % args.print_iter == 0:
meters /= args.print_iter
print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model)))
sys.stdout.flush()
meters *= 0

if args.save_iter >= 0 and total_step % args.save_iter == 0:
n_iter = total_step // args.save_iter - 1
torch.save(model.state_dict(), args.save_dir + "/model." + str(n_iter))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
with torch.no_grad():
for _ in tqdm(range(args.nsample // args.batch_size)):
smiles_list = model.sample(args.batch_size)
for _,smiles in enumerate(smiles_list):
print(smiles)

del dataset
if args.save_iter == -1:
torch.save(model.state_dict(), args.save_dir + "/model." + str(epoch))
scheduler.step()
print("learning rate: %.6f" % scheduler.get_lr()[0])
23 changes: 12 additions & 11 deletions get_vocab.py
@@ -1,8 +1,8 @@
import sys
import argparse
from hgraph import *
from rdkit import Chem
from multiprocessing import Pool
from collections import Counter

def process(data):
vocab = set()
@@ -11,26 +11,27 @@ def process(data):
hmol = MolGraph(s)
for node,attr in hmol.mol_tree.nodes(data=True):
smiles = attr['smiles']
vocab[attr['label']] += 1
vocab.add( attr['label'] )
for i,s in attr['inter_label']:
vocab[(smiles, s)] += 1
vocab.add( (smiles, s) )
return vocab

if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument('--ncpu', type=int, default=1)
args = parser.parse_args()

data = [mol for line in sys.stdin for mol in line.split()[:2]]
data = list(set(data))

ncpu = 15
batch_size = len(data) // ncpu + 1
batch_size = len(data) // args.ncpu + 1
batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]

pool = Pool(ncpu)
pool = Pool(args.ncpu)
vocab_list = pool.map(process, batches)
vocab = [(x,y) for vocab in vocab_list for x,y in vocab]
vocab = list(set(vocab))

vocab = Counter()
for c in vocab_list:
vocab |= c

for (x,y),c in vocab:
for x,y in sorted(vocab):
print(x, y)
3 changes: 1 addition & 2 deletions hgraph/__init__.py
@@ -2,6 +2,5 @@
from hgraph.encoder import HierMPNEncoder
from hgraph.decoder import HierMPNDecoder
from hgraph.vocab import Vocab, PairVocab, common_atom_vocab
from hgraph.hgnn import HierGNN, HierVGNN, HierCondVGNN
from hgraph.hgnn import HierVAE, HierVGNN, HierCondVGNN
from hgraph.dataset import MoleculeDataset, MolPairDataset, DataFolder, MolEnumRootDataset
from hgraph.stereo import restore_stereo
3 changes: 3 additions & 0 deletions hgraph/dataset.py
@@ -74,6 +74,9 @@ def __init__(self, data_folder, batch_size, shuffle=True):
self.batch_size = batch_size
self.shuffle = shuffle

def __len__(self):
return len(self.data_files) * 1000

def __iter__(self):
for fn in self.data_files:
fn = os.path.join(self.data_folder, fn)
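
A hedged note on the new `__len__`: the factor of 1000 presumably approximates how many pre-batched entries `preprocess.py` stores per `tensor*` file, so `len(...)` gives callers an estimated number of batches per epoch (useful for progress bars or learning-rate schedules). A minimal sketch under that assumption:

```
# Hedged sketch: estimate loader size up front. Assumes ~1000 batches per
# preprocessed tensor file, as the factor of 1000 in __len__ suggests.
from hgraph import DataFolder

loader = DataFolder("train_processed/", batch_size=50)
print(f"~{len(loader)} batches per epoch")
```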
