Commit

adding finetuning script
Wengong Jin committed Apr 26, 2021
1 parent 6d37153 commit 95312da
Showing 4 changed files with 190 additions and 5 deletions.
168 changes: 168 additions & 0 deletions finetune_generator.py
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

import rdkit
import math, random, sys
import numpy as np
import argparse
import os
from tqdm.auto import tqdm

import hgraph
from hgraph import HierVAE, common_atom_vocab, PairVocab
from chemprop.train import predict
from chemprop.data import MoleculeDataset, MoleculeDataLoader
from chemprop.data.utils import get_data, get_data_from_smiles
from chemprop.utils import load_args, load_checkpoint, load_scalers

# Whole-model parameter / gradient norms, used only for logging below.
param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))


class Chemprop(object):

    def __init__(self, checkpoint_dir):
        # Collect every chemprop checkpoint (*.pt) under checkpoint_dir,
        # along with its target scaler and feature scaler, for ensembling.
        self.features_generator = ['rdkit_2d_normalized']
        self.checkpoints, self.scalers, self.features_scalers = [], [], []
        for root, _, files in os.walk(checkpoint_dir):
            for fname in files:
                if fname.endswith('.pt'):
                    fname = os.path.join(root, fname)
                    scaler, features_scaler = load_scalers(fname)
                    self.scalers.append(scaler)
                    self.features_scalers.append(features_scaler)
                    model = load_checkpoint(fname)
                    self.checkpoints.append(model)

    def predict(self, smiles, batch_size=500):
        test_data = get_data_from_smiles(
            smiles=[[s] for s in smiles],
            skip_invalid_smiles=False,
            features_generator=self.features_generator
        )
        valid_indices = [i for i in range(len(test_data)) if test_data[i].mol[0] is not None]
        full_data = test_data
        test_data = MoleculeDataset([test_data[i] for i in valid_indices])
        test_data_loader = MoleculeDataLoader(dataset=test_data, batch_size=batch_size)

        sum_preds = np.zeros((len(test_data), 1))
        for model, scaler, features_scaler in zip(self.checkpoints, self.scalers, self.features_scalers):
            test_data.reset_features_and_targets()
            if features_scaler is not None:
                test_data.normalize_features(features_scaler)

            model_preds = predict(
                model=model,
                data_loader=test_data_loader,
                scaler=scaler
            )
            sum_preds += np.array(model_preds)

        # Ensemble predictions
        avg_preds = sum_preds / len(self.checkpoints)
        avg_preds = avg_preds.squeeze(-1).tolist()

        # Put zero for invalid smiles
        full_preds = [0.0] * len(full_data)
        for i, si in enumerate(valid_indices):
            full_preds[si] = avg_preds[i]

        return np.array(full_preds, dtype=np.float32)
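
The Chemprop class above is the scoring oracle for fine-tuning: it averages the predictions of every checkpoint found under a directory and returns one activity score per input SMILES, with 0.0 for molecules RDKit cannot parse. A minimal usage sketch (the checkpoint directory and SMILES are placeholder values):

# Hypothetical usage of the ensemble scorer defined above.
scorer = Chemprop('ckpt/chemprop')            # directory containing *.pt chemprop checkpoints
scores = scorer.predict(['CCO', 'c1ccccc1'])  # one float32 score per molecule
print(scores)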


if __name__ == "__main__":
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    parser = argparse.ArgumentParser()
    parser.add_argument('--train', required=True)
    parser.add_argument('--vocab', required=True)
    parser.add_argument('--atom_vocab', default=common_atom_vocab)
    parser.add_argument('--save_dir', required=True)
    parser.add_argument('--generative_model', required=True)
    parser.add_argument('--chemprop_model', required=True)
    parser.add_argument('--seed', type=int, default=7)

    parser.add_argument('--rnn_type', type=str, default='LSTM')
    parser.add_argument('--hidden_size', type=int, default=250)
    parser.add_argument('--embed_size', type=int, default=250)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--latent_size', type=int, default=32)
    parser.add_argument('--depthT', type=int, default=15)
    parser.add_argument('--depthG', type=int, default=15)
    parser.add_argument('--diterT', type=int, default=1)
    parser.add_argument('--diterG', type=int, default=3)
    parser.add_argument('--dropout', type=float, default=0.0)

    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--clip_norm', type=float, default=5.0)
    parser.add_argument('--epoch', type=int, default=50)
    parser.add_argument('--inner_epoch', type=int, default=10)
    parser.add_argument('--threshold', type=float, default=0.3)
    parser.add_argument('--nsample', type=int, default=10000)

    args = parser.parse_args()
    print(args)

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    with open(args.train) as f:
        train_smiles = [line.strip("\r\n ") for line in f]

    vocab = [x.strip("\r\n ").split() for x in open(args.vocab)]
    args.vocab = PairVocab(vocab)

    score_func = Chemprop(args.chemprop_model)  # chemprop ensemble as activity oracle
    good_smiles = train_smiles

    model = HierVAE(args).cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    print('Loading from checkpoint ' + args.generative_model)
    model_state, optimizer_state, _, beta = torch.load(args.generative_model)
    model.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)

    # Outer loop: fine-tune on the growing pool of actives, then sample and filter.
    for epoch in range(args.epoch):
        good_smiles = sorted(set(good_smiles))  # deduplicate the pool
        random.shuffle(good_smiles)
        dataset = hgraph.MoleculeDataset(good_smiles, args.vocab, args.atom_vocab, args.batch_size)

        print(f'Epoch {epoch} training...')
        for _ in range(args.inner_epoch):
            meters = np.zeros(6)
            dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0], shuffle=True)
            for batch in tqdm(dataloader):
                model.zero_grad()
                loss, kl_div, wacc, iacc, tacc, sacc = model(*batch, beta=beta)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
                optimizer.step()
                meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100])

            meters /= len(dataset)
            print("Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model)))

        ckpt = (model.state_dict(), optimizer.state_dict(), epoch, beta)
        torch.save(ckpt, os.path.join(args.save_dir, f"model.ckpt.{epoch}"))

        print(f'Epoch {epoch} decoding...')
        decoded_smiles = []
        with torch.no_grad():
            for _ in tqdm(range(args.nsample // args.batch_size)):
                outputs = model.sample(args.batch_size, greedy=True)
                decoded_smiles.extend(outputs)

        print(f'Epoch {epoch} filtering...')
        scores = score_func.predict(decoded_smiles)
        good_entries = [(s, p) for s, p in zip(decoded_smiles, scores) if p >= args.threshold]
        print(f'Discovered {len(good_entries)} active molecules')
        good_smiles += [s for s, p in good_entries]
        with open(os.path.join(args.save_dir, f"new_molecules.{epoch}"), 'w') as f:
            for s, p in zip(decoded_smiles, scores):
                print(s, p, file=f)
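
Taken together, the main loop alternates three phases per epoch: fine-tune the generator on the current pool of actives, greedily decode args.nsample candidate molecules, then keep those the chemprop ensemble scores at or above args.threshold, growing the pool for the next epoch. A hypothetical invocation, with all paths as placeholders:

python finetune_generator.py \
    --train data/actives.txt --vocab data/vocab.txt \
    --generative_model ckpt/pretrained.model \
    --chemprop_model ckpt/chemprop \
    --save_dir ckpt/finetune --threshold 0.5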

4 changes: 2 additions & 2 deletions generate.py
@@ -41,15 +41,15 @@

     model = HierVAE(args).cuda()

-    model.load_state_dict(torch.load(args.model))
+    model.load_state_dict(torch.load(args.model)[0])
     model.eval()

     torch.manual_seed(args.seed)
     random.seed(args.seed)

     with torch.no_grad():
         for _ in tqdm(range(args.nsample // args.batch_size)):
-            smiles_list = model.sample(args.batch_size)
+            smiles_list = model.sample(args.batch_size, greedy=True)
             for _, smiles in enumerate(smiles_list):
                 print(smiles)
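
Both edits track conventions introduced by finetune_generator.py: checkpoints are now saved as a tuple rather than a bare state dict, and sample() now takes an explicit greedy flag. A short sketch of the checkpoint layout (path is a placeholder):

# Checkpoint tuple as written by finetune_generator.py:
#   (model_state_dict, optimizer_state_dict, epoch, beta)
model_state, optimizer_state, epoch, beta = torch.load(path)
model.load_state_dict(model_state)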

19 changes: 18 additions & 1 deletion hgraph/dataset.py
@@ -7,10 +7,24 @@
 from hgraph.chemutils import get_leaves
 from hgraph.mol_graph import MolGraph

+
 class MoleculeDataset(Dataset):

     def __init__(self, data, vocab, avocab, batch_size):
-        self.batches = [data[i : i + batch_size] for i in range(0, len(data), batch_size)]
+        safe_data = []
+        for mol_s in data:
+            hmol = MolGraph(mol_s)
+            ok = True
+            for node, attr in hmol.mol_tree.nodes(data=True):
+                smiles = attr['smiles']
+                ok &= attr['label'] in vocab.vmap
+                for i, s in attr['inter_label']:
+                    ok &= (smiles, s) in vocab.vmap
+            if ok:
+                safe_data.append(mol_s)
+
+        print(f'After pruning {len(data)} -> {len(safe_data)}')
+        self.batches = [safe_data[i : i + batch_size] for i in range(0, len(safe_data), batch_size)]
         self.vocab = vocab
         self.avocab = avocab

@@ -20,6 +34,7 @@ def __len__(self):
     def __getitem__(self, idx):
         return MolGraph.tensorize(self.batches[idx], self.vocab, self.avocab)

+
 class MolEnumRootDataset(Dataset):

     def __init__(self, data, vocab, avocab):
@@ -50,6 +65,7 @@ def __getitem__(self, idx):
         else:
             return None

+
 class MolPairDataset(Dataset):

     def __init__(self, data, vocab, avocab, batch_size):
@@ -66,6 +82,7 @@ def __getitem__(self, idx):
         y = MolGraph.tensorize(y, self.vocab, self.avocab)
         return x + y

+
 class DataFolder(object):

     def __init__(self, data_folder, batch_size, shuffle=True):
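
The new pruning step protects fine-tuning runs against motifs outside the fixed motif vocabulary: a molecule whose hierarchical tree contains a motif label (or attachment inter_label) absent from vocab.vmap would fail during MolGraph.tensorize, so it is dropped up front. The same membership test in isolation, as a hedged sketch:

def in_vocab(smiles, vocab):
    # True iff every motif and attachment label of the molecule's
    # hierarchical tree appears in the motif vocabulary.
    hmol = MolGraph(smiles)
    for _, attr in hmol.mol_tree.nodes(data=True):
        if attr['label'] not in vocab.vmap:
            return False
        if any((attr['smiles'], s) not in vocab.vmap for _, s in attr['inter_label']):
            return False
    return True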
4 changes: 2 additions & 2 deletions hgraph/hgnn.py
@@ -37,9 +37,9 @@ def rsample(self, z_vecs, W_mean, W_var, perturb=True):
         z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon if perturb else z_mean
         return z_vecs, kl_loss

-    def sample(self, batch_size):
+    def sample(self, batch_size, greedy):
         root_vecs = torch.randn(batch_size, self.latent_size).cuda()
-        return self.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=True, max_decode_step=150)
+        return self.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=greedy, max_decode_step=150)

     def reconstruct(self, batch):
         graphs, tensors, _ = batch
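
This change exposes the decoder's greedy flag to callers instead of hardcoding greedy=True: generate.py and finetune_generator.py both decode deterministically from random latent vectors, and stochastic decoding is now one argument away. A sketch of both call styles, assuming model is a loaded HierVAE:

with torch.no_grad():
    greedy_smiles = model.sample(batch_size=20, greedy=True)   # argmax decoding
    random_smiles = model.sample(batch_size=20, greedy=False)  # stochastic decoding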
