Skip to content

Commit

Permalink
training and evaluation code
Browse files Browse the repository at this point in the history
  • Loading branch information
awfderry committed Jun 4, 2021
1 parent 50b9b6b commit 5701486
Show file tree
Hide file tree
Showing 28 changed files with 868 additions and 190 deletions.
2 changes: 1 addition & 1 deletion atom3d/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .datasets import LMDBDataset, PDBDataset, SilentDataset, load_dataset, make_lmdb_dataset, get_file_list, extract_coordinates_as_numpy_arrays, download_dataset
from .datasets import LMDBDataset, PDBDataset, PTGDataset, SilentDataset, load_dataset, make_lmdb_dataset, get_file_list, extract_coordinates_as_numpy_arrays, download_dataset
23 changes: 23 additions & 0 deletions atom3d/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, IterableDataset
import torch_geometric.data as ptg
import torch

import atom3d.util.rosetta as ar
import atom3d.util.file as fi
Expand Down Expand Up @@ -261,6 +263,7 @@ def __getitem__(self, index: int):
if self._gdb:
item['labels'] = data
item['freq'] = freq
item['smiles'] = smiles
if self._transform:
item = self._transform(item)
return item
Expand Down Expand Up @@ -313,6 +316,26 @@ def __getitem__(self, index: int):
item = self._transform(item)
return item

class PTGDataset(ptg.Dataset):
    """Dataset of pre-processed PyTorch Geometric graphs, stored as one
    ``data_<idx>.pt`` file per example directly under ``root``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def processed_dir(self):
        # Graph files live directly in the root directory (no 'processed/' subfolder).
        return self.root

    @property
    def processed_file_names(self):
        # Report one known file so the base class treats the data as already processed.
        return ['data_1.pt']

    def len(self):
        # One serialized graph per file.
        # NOTE(review): assumes the directory contains only data_*.pt files — confirm.
        return len(os.listdir(self.processed_dir))

    def get(self, idx):
        fname = 'data_{}.pt'.format(idx)
        return torch.load(os.path.join(self.processed_dir, fname))

def serialize(x, serialization_format):
"""
Expand Down
62 changes: 57 additions & 5 deletions atom3d/datasets/smp/prepare_lmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import os
import re
import sys
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem.rdchem import HybridizationType
from rdkit import RDConfig
from rdkit.Chem import ChemicalFeatures

import click
import numpy as np
Expand All @@ -13,9 +18,51 @@
import atom3d.util.file as fi
import atom3d.util.formats as fo


logger = logging.getLogger(__name__)
fdef_name = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)

def _get_rdkit_data(smiles):
    """Compute the bond list and per-atom features for a molecule via RDKit.

    :param smiles: SMILES string of the molecule; explicit hydrogens are added.
    :return: tuple ``(bonds_df, atom_feats)`` where ``atom_feats`` is
        ``[atomic_number, acceptor, donor, aromatic, sp, sp2, sp3, num_hs]``,
        each a list with one entry per atom.
    """
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    bonds_df = fo.get_bonds_list_from_mol(mol)
    # QM9 element set; anything outside it raises KeyError below.
    type_mapping = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}

    type_idx = []
    atomic_number = []
    acceptor = []
    donor = []
    aromatic = []
    sp = []
    sp2 = []
    sp3 = []
    num_hs = []
    for atom in mol.GetAtoms():
        type_idx.append(type_mapping[atom.GetSymbol()])
        atomic_number.append(atom.GetAtomicNum())
        # Donor/acceptor flags start at 0 and are filled in from the
        # chemical-feature factory below.
        donor.append(0)
        acceptor.append(0)
        aromatic.append(int(atom.GetIsAromatic()))
        hyb = atom.GetHybridization()
        sp.append(int(hyb == HybridizationType.SP))
        sp2.append(int(hyb == HybridizationType.SP2))
        sp3.append(int(hyb == HybridizationType.SP3))
        num_hs.append(atom.GetTotalNumHs(includeNeighbors=True))

    # Mark hydrogen-bond donors/acceptors using the module-level feature factory.
    for feat in factory.GetFeaturesForMol(mol):
        family = feat.GetFamily()
        if family == 'Donor':
            for k in feat.GetAtomIds():
                donor[k] = 1
        elif family == 'Acceptor':
            for k in feat.GetAtomIds():
                acceptor[k] = 1
    atom_feats = [atomic_number, acceptor, donor, aromatic, sp, sp2, sp3, num_hs]

    return bonds_df, atom_feats


def _add_data_with_subtracted_thermochem_energy(x):
"""
Expand All @@ -40,6 +87,7 @@ def _add_data_with_subtracted_thermochem_energy(x):
cv_atom = data[14] - np.sum([c * thchem_en[el][4] for el, c in counts.items()]) # Cv
# Append new data
x['labels'] += [u0_atom, u_atom, h_atom, g_atom, cv_atom]
x['bonds'], x['atom_feats'] = _get_rdkit_data(x['smiles'])
# Delete the file path
del x['file_path']
return x
Expand All @@ -56,6 +104,10 @@ def _write_split_indices(split_txt, lmdb_ds, output_txt):
f.write(str('\n'.join([str(i) for i in split_indices])))
return split_indices

def bond_filter(item):
    """Filter predicate for ``make_lmdb_dataset``: return True to drop an
    example whose molecule has no bonds.

    :param item: dataset item containing a 'bonds' table (as produced by
        ``_get_rdkit_data``).
    :return: True if the item should be filtered out, False otherwise.
    """
    # len() is used instead of truthiness because 'bonds' is a DataFrame.
    return len(item['bonds']) == 0

@click.command(help='Prepare SMP dataset')
@click.argument('input_file_path', type=click.Path())
Expand All @@ -78,7 +130,7 @@ def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
lmdb_path = os.path.join(output_root, 'all')
logger.info(f'Creating lmdb dataset into {lmdb_path:}...')
dataset = da.load_dataset(file_list, filetype, transform=_add_data_with_subtracted_thermochem_energy)
da.make_lmdb_dataset(dataset, lmdb_path)
da.make_lmdb_dataset(dataset, lmdb_path, filter_fn=bond_filter)
# Only continue if we want to write split datasets
if not split:
return
Expand All @@ -91,9 +143,9 @@ def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt):
indices_test = _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt'))
# Write the split datasets
train_dataset, val_dataset, test_dataset = spl.split(lmdb_ds, indices_train, indices_val, indices_test)
da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'))
da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'))
da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'))
da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'), filter_fn=bond_filter)
da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'), filter_fn=bond_filter)
da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'), filter_fn=bond_filter)


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions atom3d/datasets/smp/prepare_lmdb.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

#SBATCH --time=24:00:00
#SBATCH --mem=20G
#SBATCH --partition=rondror
#SBATCH --qos=high_p
#SBATCH --partition=rbaltman,owners
# # SBATCH --qos=high_p


# Directory definitions
OUT_DIR=/oak/stanford/groups/rondror/projects/atom3d/lmdb/small_molecule_properties
OUT_DIR=/scratch/users/aderry/lmdb/atom3d/small_molecule_properties
XYZ_DIR=$SCRATCH/dsgdb9nsd
START_DIR=$(pwd)

Expand Down Expand Up @@ -36,6 +36,6 @@ python prepare_lmdb.py $XYZ_DIR $OUT_DIR --split \
--val_txt splits/ids_validation.txt \
--test_txt splits/ids_test.txt
# Remove the raw data
rm $SCRATCH/dsgdb9nsd.xyz.tar.bz2
rm -r $XYZ_DIR
# rm $SCRATCH/dsgdb9nsd.xyz.tar.bz2
# rm -r $XYZ_DIR

31 changes: 21 additions & 10 deletions atom3d/util/graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import scipy.spatial as ss
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_sparse import coalesce

import atom3d.util.formats as fo

Expand Down Expand Up @@ -58,7 +60,7 @@ def prot_df_to_graph(df, feat_col='element', allowable_feats=prot_atoms, edge_di
return node_feats, edges, edge_weights, node_pos


def mol_df_to_graph(df, bonds=None, allowable_atoms=mol_atoms, edge_dist_cutoff=4.5):
def mol_df_to_graph(df, bonds=None, allowable_atoms=mol_atoms, edge_dist_cutoff=4.5, onehot_edges=True):
"""
Converts molecule in dataframe to a graph compatible with Pytorch-Geometric
Expand All @@ -71,27 +73,33 @@ def mol_df_to_graph(df, bonds=None, allowable_atoms=mol_atoms, edge_dist_cutoff=
:return: Tuple containing \n
- node_feats (torch.FloatTensor): Features for each node, one-hot encoded by atom type in ``allowable_atoms``.
- edges (torch.LongTensor): Edges from chemical bond graph in COO format.
- edge_index (torch.LongTensor): Edges from chemical bond graph in COO format.
- edge_feats (torch.FloatTensor): Edge features given by bond type. Single = 1.0, Double = 2.0, Triple = 3.0, Aromatic = 1.5.
- node_pos (torch.FloatTensor): x-y-z coordinates of each node.
"""
node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
N = df.shape[0]
bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}

if bonds is not None:
bond_data = torch.FloatTensor(bonds.to_numpy())
edge_tuples = torch.cat((bond_data[:, :2], torch.flip(bond_data[:, :2], dims=(1,))), dim=0)
edges = edge_tuples.t().long().contiguous()
edge_feats = torch.cat((torch.FloatTensor(bond_data[:,-1]).view(-1), torch.FloatTensor(bond_data[:,-1]).view(-1)), dim=0)
edge_index = edge_tuples.t().long().contiguous()
if onehot_edges:
bond_idx = list(map(lambda x: bond_mapping[x], bond_data[:,-1].tolist())) + list(map(lambda x: bond_mapping[x], bond_data[:,-1].tolist()))
edge_attr = F.one_hot(torch.tensor(bond_idx), num_classes=4).to(torch.float)
edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
else:
edge_attr = torch.cat((torch.FloatTensor(bond_data[:,-1]).view(-1), torch.FloatTensor(bond_data[:,-1]).view(-1)), dim=0)
else:
kd_tree = ss.KDTree(node_pos)
edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
edges = torch.LongTensor(edge_tuples).t().contiguous()
edges = to_undirected(edges)
edge_feats = torch.FloatTensor([1.0 / (np.linalg.norm(node_pos[i] - node_pos[j]) + 1e-5) for i, j in edges.t()]).view(-1)

edge_index = torch.LongTensor(edge_tuples).t().contiguous()
edge_index = to_undirected(edge_index)
edge_attr = torch.FloatTensor([1.0 / (np.linalg.norm(node_pos[i] - node_pos[j]) + 1e-5) for i, j in edge_index.t()]).view(-1)
node_feats = torch.FloatTensor([one_of_k_encoding_unk(e, allowable_atoms) for e in df['element']])

return node_feats, edges, edge_feats, node_pos
return node_feats, edge_index, edge_attr, node_pos


def combine_graphs(graph1, graph2, edges_between=True, edges_between_dist=4.5):
Expand Down Expand Up @@ -158,7 +166,10 @@ def edges_between_graphs(pos1, pos2, dist=4.5):
continue
for j in contacts:
edges.append((i, j + pos1.shape[0]))
edge_weights.append(np.linalg.norm(pos1[i] - pos2[j]))
edges.append((j + pos1.shape[0], i))
d = 1.0 / (np.linalg.norm(pos1[i] - pos2[j]) + 1e-5)
edge_weights.append(d)
edge_weights.append(d)

edges = torch.LongTensor(edges).t().contiguous()
edge_weights = torch.FloatTensor(edge_weights).view(-1)
Expand Down
39 changes: 39 additions & 0 deletions atom3d/util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,42 @@ def evaluate_average(results, metric=r2, verbose=True, select=None):
print(' Validation: %7.3f +/- %7.3f'%summary_va)
print(' Test: %7.3f +/- %7.3f'%summary_te)
return summary_tr, summary_va, summary_te

def _per_target_mean(res, metric):
all_res = []
for targets, predictions in res:
all_res.append(metric(targets, predictions))
return np.mean(all_res)

def evaluate_per_target_average(results, metric=r2, verbose=True):
    """
    Calculate metric for training, validation and test data, averaged over all
    replicates, where each split's score is first averaged over its targets.

    :param results: mapping rep-name -> {'train': ..., 'valid': ..., 'test': ...},
        each split being a sequence of (targets, predictions) pairs.
    :param metric: callable ``metric(targets, predictions) -> float``.
    :param verbose: when True, print per-rep values and the final summary.
    :return: three ``(mean, std)`` tuples for train, validation and test.
    """
    n_reps = len(results)
    metric_tr = np.empty(n_reps)
    metric_va = np.empty(n_reps)
    metric_te = np.empty(n_reps)
    # One pass per training repetition.
    for r, rep in enumerate(results):
        rep_results = results[rep]
        metric_tr[r] = _per_target_mean(rep_results['train'], metric)
        metric_va[r] = _per_target_mean(rep_results['valid'], metric)
        metric_te[r] = _per_target_mean(rep_results['test'], metric)
        if verbose: print(' - %s - Training: %7.3f - Validation: %7.3f - Test: %7.3f'%(rep, metric_tr[r], metric_va[r], metric_te[r]))
    # Mean and corresponding standard deviations across reps.
    summary_tr = (np.mean(metric_tr), np.std(metric_tr))
    summary_va = (np.mean(metric_va), np.std(metric_va))
    summary_te = (np.mean(metric_te), np.std(metric_te))
    if verbose:
        print('---')
        print(' Training: %7.3f +/- %7.3f'%summary_tr)
        print(' Validation: %7.3f +/- %7.3f'%summary_va)
        print(' Test: %7.3f +/- %7.3f'%summary_te)
    return summary_tr, summary_va, summary_te
45 changes: 37 additions & 8 deletions atom3d/util/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import torch
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats as stats

Expand Down Expand Up @@ -44,27 +45,55 @@ def __init__(self, name, reps=[1,2,3]):

def get_prediction(self, prediction_fn):
"""
Reads targets and prediction.
TODO: Implement this!
Reads targets and prediction
"""
targets, predict = None, None
pr_data = torch.load(prediction_fn)
targets = np.array( pr_data['targets'] )
predict = np.array( pr_data['predictions'] )
return targets, predict

def get_all_predictions(self):
results = {}
for r, rep in enumerate(self.reps):
prediction_fn = self.name + '-rep'+str(int(rep))+'.best'
# Load the Cormorant predictions
targets_tr, predict_tr = self.get_prediction(prediction_fn+'.train.pt')
targets_va, predict_va = self.get_prediction(prediction_fn+'.valid.pt')
targets_va, predict_va = self.get_prediction(prediction_fn+'.val.pt')
targets_te, predict_te = self.get_prediction(prediction_fn+'.test.pt')
targets = {'train':targets_tr, 'valid':targets_va, 'test':targets_te}
predict = {'train':predict_tr, 'valid':predict_va, 'test':predict_te}
results['rep'+str(int(rep))] = {'targets':targets, 'predict':predict}
return results


def get_predictions_by_target(self, prediction_fn):
results_df = pd.DataFrame(torch.load(prediction_fn))
per_target = []
for key, val in results_df.groupby(['target']):
# Ignore target with 2 decoys only since the correlations are
# not really meaningful.
if val.shape[0] < 3:
continue
true = val['true'].astype(float).to_numpy()
pred = val['pred'].astype(float).to_numpy()
per_target.append((true, pred))
global_true = results_df['true'].astype(float).to_numpy()
global_pred = results_df['pred'].astype(float).to_numpy()
return global_true, global_pred, per_target

def get_target_specific_predictions(self):
"""For use with PSR/RSR. Here `target` refers to the protein target, not the prediction target."""
results = {'global':{}, 'per_target':{}}
for r, rep in enumerate(self.reps):
prediction_fn = self.name + '-rep'+str(int(rep))+'.best'
targets_tr, predict_tr, per_target_tr = self.get_predictions_by_target(prediction_fn+'.train.pt')
targets_va, predict_va, per_target_va = self.get_predictions_by_target(prediction_fn+'.val.pt')
targets_te, predict_te, per_target_te = self.get_predictions_by_target(prediction_fn+'.test.pt')
targets = {'train':targets_tr, 'valid':targets_va, 'test':targets_te}
predict = {'train':predict_tr, 'valid':predict_va, 'test':predict_te}
per_target = {'train': per_target_tr, 'valid':per_target_va, 'test':per_target_te}
results['global']['rep'+str(int(rep))] = {'targets':targets, 'predict':predict}
results['per_target']['rep'+str(int(rep))] = per_target
return results


class ResultsENN():

Expand Down
4 changes: 2 additions & 2 deletions atom3d/util/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def prot_graph_transform(item, atom_keys=['atoms'], label_key='scores'):

return item

def mol_graph_transform(item, atom_key='atoms', label_key='scores', use_bonds=False):
def mol_graph_transform(item, atom_key='atoms', label_key='scores', use_bonds=False, onehot_edges=False):
"""Transform for converting dataframes to Pytorch Geometric graphs, to be applied when defining a :mod:`Dataset <atom3d.datasets.datasets>`.
Operates on Dataset items, assumes that the item contains all keys specified in ``keys`` and ``labels`` arguments.
Expand All @@ -67,7 +67,7 @@ def mol_graph_transform(item, atom_key='atoms', label_key='scores', use_bonds=Fa
bonds = item['bonds']
else:
bonds = None
node_feats, edge_index, edge_feats, pos = gr.mol_df_to_graph(item[atom_key], bonds=bonds)
node_feats, edge_index, edge_feats, pos = gr.mol_df_to_graph(item[atom_key], bonds=bonds, onehot_edges=onehot_edges)
item[atom_key] = Data(node_feats, edge_index, edge_feats, y=item[label_key], pos=pos)

return item
Expand Down
Loading

0 comments on commit 5701486

Please sign in to comment.