diff --git a/atom3d/datasets/__init__.py b/atom3d/datasets/__init__.py index b6943d4..6cc9cb5 100644 --- a/atom3d/datasets/__init__.py +++ b/atom3d/datasets/__init__.py @@ -1 +1 @@ -from .datasets import LMDBDataset, PDBDataset, SilentDataset, load_dataset, make_lmdb_dataset, get_file_list, extract_coordinates_as_numpy_arrays, download_dataset +from .datasets import LMDBDataset, PDBDataset, PTGDataset, SilentDataset, load_dataset, make_lmdb_dataset, get_file_list, extract_coordinates_as_numpy_arrays, download_dataset diff --git a/atom3d/datasets/datasets.py b/atom3d/datasets/datasets.py index 06a3bd2..4d2722b 100644 --- a/atom3d/datasets/datasets.py +++ b/atom3d/datasets/datasets.py @@ -17,6 +17,8 @@ import numpy as np import pandas as pd from torch.utils.data import Dataset, IterableDataset +import torch_geometric.data as ptg +import torch import atom3d.util.rosetta as ar import atom3d.util.file as fi @@ -261,6 +263,7 @@ def __getitem__(self, index: int): if self._gdb: item['labels'] = data item['freq'] = freq + item['smiles'] = smiles if self._transform: item = self._transform(item) return item @@ -313,6 +316,26 @@ def __getitem__(self, index: int): item = self._transform(item) return item +class PTGDataset(ptg.Dataset): + def __init__(self, root, transform=None, pre_transform=None): + super(PTGDataset, self).__init__(root, transform, pre_transform) + + + @property + def processed_dir(self): + return self.root + + @property + def processed_file_names(self): + return ['data_1.pt'] + + + def len(self): + return len(os.listdir(self.processed_dir)) + + def get(self, idx): + data = torch.load(os.path.join(self.processed_dir, 'data_{}.pt'.format(idx))) + return data def serialize(x, serialization_format): """ diff --git a/atom3d/datasets/smp/prepare_lmdb.py b/atom3d/datasets/smp/prepare_lmdb.py index 2cf4298..b77a936 100644 --- a/atom3d/datasets/smp/prepare_lmdb.py +++ b/atom3d/datasets/smp/prepare_lmdb.py @@ -3,6 +3,11 @@ import os import re import sys +from rdkit import Chem +from rdkit import rdBase +from rdkit.Chem.rdchem import HybridizationType +from rdkit import RDConfig +from rdkit.Chem import ChemicalFeatures import click import numpy as np @@ -13,9 +18,51 @@ import atom3d.util.file as fi import atom3d.util.formats as fo - logger = logging.getLogger(__name__) +fdef_name = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef') +factory = ChemicalFeatures.BuildFeatureFactory(fdef_name) + +def _get_rdkit_data(smiles): + mol = Chem.MolFromSmiles(smiles) + mol = Chem.AddHs(mol) + bonds_df = fo.get_bonds_list_from_mol(mol) + type_mapping = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} + + type_idx = [] + atomic_number = [] + acceptor = [] + donor = [] + aromatic = [] + sp = [] + sp2 = [] + sp3 = [] + num_hs = [] + for atom in mol.GetAtoms(): + type_idx.append(type_mapping[atom.GetSymbol()]) + atomic_number.append(atom.GetAtomicNum()) + donor.append(0) + acceptor.append(0) + aromatic.append(1 if atom.GetIsAromatic() else 0) + hybridization = atom.GetHybridization() + sp.append(1 if hybridization == HybridizationType.SP else 0) + sp2.append(1 if hybridization == HybridizationType.SP2 else 0) + sp3.append(1 if hybridization == HybridizationType.SP3 else 0) + num_hs.append(atom.GetTotalNumHs(includeNeighbors=True)) + feats = factory.GetFeaturesForMol(mol) + for j in range(0, len(feats)): + if feats[j].GetFamily() == 'Donor': + node_list = feats[j].GetAtomIds() + for k in node_list: + donor[k] = 1 + elif feats[j].GetFamily() == 'Acceptor': + node_list = feats[j].GetAtomIds() + for k in node_list: + 
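# [Review note] Minimal usage sketch for the new PTGDataset above -- a sketch,
# not part of the patch. It assumes `root` contains only files named
# data_0.pt, data_1.pt, ...; note that len() counts *every* file in the
# directory, so a stray file in `root` makes len() disagree with valid indices.
import os
import torch

def save_example_graphs(root, graphs):
    # Precompute PyG Data objects to disk in the layout PTGDataset expects.
    os.makedirs(root, exist_ok=True)
    for i, g in enumerate(graphs):
        torch.save(g, os.path.join(root, f'data_{i}.pt'))

# from atom3d.datasets import PTGDataset
# dataset = PTGDataset(root='/tmp/ptg_example')   # hypothetical path
# print(len(dataset), dataset.get(0))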
acceptor[k] = 1 + atom_feats = [atomic_number, acceptor, donor, aromatic, sp, sp2, sp3, num_hs] + + return bonds_df, atom_feats + def _add_data_with_subtracted_thermochem_energy(x): """ @@ -40,6 +87,7 @@ def _add_data_with_subtracted_thermochem_energy(x): cv_atom = data[14] - np.sum([c * thchem_en[el][4] for el, c in counts.items()]) # Cv # Append new data x['labels'] += [u0_atom, u_atom, h_atom, g_atom, cv_atom] + x['bonds'], x['atom_feats'] = _get_rdkit_data(x['smiles']) # Delete the file path del x['file_path'] return x @@ -56,6 +104,10 @@ def _write_split_indices(split_txt, lmdb_ds, output_txt): f.write(str('\n'.join([str(i) for i in split_indices]))) return split_indices +def bond_filter(item): + if len(item['bonds']) == 0: + return True + return False @click.command(help='Prepare SMP dataset') @click.argument('input_file_path', type=click.Path()) @@ -78,7 +130,7 @@ def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt): lmdb_path = os.path.join(output_root, 'all') logger.info(f'Creating lmdb dataset into {lmdb_path:}...') dataset = da.load_dataset(file_list, filetype, transform=_add_data_with_subtracted_thermochem_energy) - da.make_lmdb_dataset(dataset, lmdb_path) + da.make_lmdb_dataset(dataset, lmdb_path, filter_fn=bond_filter) # Only continue if we want to write split datasets if not split: return @@ -91,9 +143,9 @@ def prepare(input_file_path, output_root, split, train_txt, val_txt, test_txt): indices_test = _write_split_indices(test_txt, lmdb_ds, os.path.join(output_root, 'test_indices.txt')) # Write the split datasets train_dataset, val_dataset, test_dataset = spl.split(lmdb_ds, indices_train, indices_val, indices_test) - da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train')) - da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val')) - da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test')) + da.make_lmdb_dataset(train_dataset, os.path.join(output_root, 'train'), filter_fn=bond_filter) + da.make_lmdb_dataset(val_dataset, os.path.join(output_root, 'val'), filter_fn=bond_filter) + da.make_lmdb_dataset(test_dataset, os.path.join(output_root, 'test'), filter_fn=bond_filter) if __name__ == "__main__": diff --git a/atom3d/datasets/smp/prepare_lmdb.sh b/atom3d/datasets/smp/prepare_lmdb.sh index 9e74bce..863a749 100755 --- a/atom3d/datasets/smp/prepare_lmdb.sh +++ b/atom3d/datasets/smp/prepare_lmdb.sh @@ -2,12 +2,12 @@ #SBATCH --time=24:00:00 #SBATCH --mem=20G -#SBATCH --partition=rondror -#SBATCH --qos=high_p +#SBATCH --partition=rbaltman,owners +# # SBATCH --qos=high_p # Directory definitions -OUT_DIR=/oak/stanford/groups/rondror/projects/atom3d/lmdb/small_molecule_properties +OUT_DIR=/scratch/users/aderry/lmdb/atom3d/small_molecule_properties XYZ_DIR=$SCRATCH/dsgdb9nsd START_DIR=$(pwd) @@ -36,6 +36,6 @@ python prepare_lmdb.py $XYZ_DIR $OUT_DIR --split \ --val_txt splits/ids_validation.txt \ --test_txt splits/ids_test.txt # Remove the raw data -rm $SCRATCH/dsgdb9nsd.xyz.tar.bz2 -rm -r $XYZ_DIR +# rm $SCRATCH/dsgdb9nsd.xyz.tar.bz2 +# rm -r $XYZ_DIR diff --git a/atom3d/util/graph.py b/atom3d/util/graph.py index b228b00..fbc0f8c 100644 --- a/atom3d/util/graph.py +++ b/atom3d/util/graph.py @@ -1,7 +1,9 @@ import numpy as np import scipy.spatial as ss import torch +import torch.nn.functional as F from torch_geometric.utils import to_undirected +from torch_sparse import coalesce import atom3d.util.formats as fo @@ -58,7 +60,7 @@ def prot_df_to_graph(df, feat_col='element', allowable_feats=prot_atoms, edge_di return 
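# [Review note] Hedged sketch of the filter_fn contract that bond_filter above
# appears to rely on: make_lmdb_dataset is expected to *skip* an item when
# filter_fn returns True. Not part of the patch; `items` is a stand-in dataset.
def bond_filter(item):
    # Drop molecules for which RDKit produced no bonds.
    return len(item['bonds']) == 0

items = [{'bonds': [(0, 1, 1.0)]}, {'bonds': []}]
kept = [x for x in items if not bond_filter(x)]
print(len(kept))  # -> 1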
node_feats, edges, edge_weights, node_pos -def mol_df_to_graph(df, bonds=None, allowable_atoms=mol_atoms, edge_dist_cutoff=4.5): +def mol_df_to_graph(df, bonds=None, allowable_atoms=mol_atoms, edge_dist_cutoff=4.5, onehot_edges=True): """ Converts molecule in dataframe to a graph compatible with Pytorch-Geometric @@ -71,27 +73,33 @@ def mol_df_to_graph(df, bonds=None, allowable_atoms=mol_atoms, edge_dist_cutoff= :return: Tuple containing \n - node_feats (torch.FloatTensor): Features for each node, one-hot encoded by atom type in ``allowable_atoms``. - - edges (torch.LongTensor): Edges from chemical bond graph in COO format. + - edge_index (torch.LongTensor): Edges from chemical bond graph in COO format. - edge_feats (torch.FloatTensor): Edge features given by bond type. Single = 1.0, Double = 2.0, Triple = 3.0, Aromatic = 1.5. - node_pos (torch.FloatTensor): x-y-z coordinates of each node. """ node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy()) + N = df.shape[0] + bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3} if bonds is not None: bond_data = torch.FloatTensor(bonds.to_numpy()) edge_tuples = torch.cat((bond_data[:, :2], torch.flip(bond_data[:, :2], dims=(1,))), dim=0) - edges = edge_tuples.t().long().contiguous() - edge_feats = torch.cat((torch.FloatTensor(bond_data[:,-1]).view(-1), torch.FloatTensor(bond_data[:,-1]).view(-1)), dim=0) + edge_index = edge_tuples.t().long().contiguous() + if onehot_edges: + bond_idx = list(map(lambda x: bond_mapping[x], bond_data[:,-1].tolist())) + list(map(lambda x: bond_mapping[x], bond_data[:,-1].tolist())) + edge_attr = F.one_hot(torch.tensor(bond_idx), num_classes=4).to(torch.float) + edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N) + else: + edge_attr = torch.cat((torch.FloatTensor(bond_data[:,-1]).view(-1), torch.FloatTensor(bond_data[:,-1]).view(-1)), dim=0) else: kd_tree = ss.KDTree(node_pos) edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff)) - edges = torch.LongTensor(edge_tuples).t().contiguous() - edges = to_undirected(edges) - edge_feats = torch.FloatTensor([1.0 / (np.linalg.norm(node_pos[i] - node_pos[j]) + 1e-5) for i, j in edges.t()]).view(-1) - + edge_index = torch.LongTensor(edge_tuples).t().contiguous() + edge_index = to_undirected(edge_index) + edge_attr = torch.FloatTensor([1.0 / (np.linalg.norm(node_pos[i] - node_pos[j]) + 1e-5) for i, j in edge_index.t()]).view(-1) node_feats = torch.FloatTensor([one_of_k_encoding_unk(e, allowable_atoms) for e in df['element']]) - return node_feats, edges, edge_feats, node_pos + return node_feats, edge_index, edge_attr, node_pos def combine_graphs(graph1, graph2, edges_between=True, edges_between_dist=4.5): @@ -158,7 +166,10 @@ def edges_between_graphs(pos1, pos2, dist=4.5): continue for j in contacts: edges.append((i, j + pos1.shape[0])) - edge_weights.append(np.linalg.norm(pos1[i] - pos2[j])) + edges.append((j + pos1.shape[0], i)) + d = 1.0 / (np.linalg.norm(pos1[i] - pos2[j]) + 1e-5) + edge_weights.append(d) + edge_weights.append(d) edges = torch.LongTensor(edges).t().contiguous() edge_weights = torch.FloatTensor(edge_weights).view(-1) diff --git a/atom3d/util/metrics.py b/atom3d/util/metrics.py index d863056..b86cf9a 100644 --- a/atom3d/util/metrics.py +++ b/atom3d/util/metrics.py @@ -83,3 +83,42 @@ def evaluate_average(results, metric=r2, verbose=True, select=None): print(' Validation: %7.3f +/- %7.3f'%summary_va) print(' Test: %7.3f +/- %7.3f'%summary_te) return summary_tr, summary_va, summary_te + +def _per_target_mean(res, metric): + all_res = [] + for targets, 
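# [Review note] Self-contained sketch of the onehot_edges path above: bond
# edges are duplicated in both directions, bond orders are mapped to class
# indices and one-hot encoded, and torch_sparse.coalesce sorts the edge index
# and merges any duplicate edges (summing their attributes). Assumes
# torch_sparse is installed; synthetic 3-atom molecule.
import torch
import torch.nn.functional as F
from torch_sparse import coalesce

bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
bonds = torch.tensor([[0, 1, 1.0], [1, 2, 2.0]])  # (src, dst, bond order)
edge_index = torch.cat((bonds[:, :2],
                        torch.flip(bonds[:, :2], dims=(1,))), dim=0).t().long()
bond_idx = [bond_mapping[b] for b in bonds[:, -1].tolist()] * 2  # both directions
edge_attr = F.one_hot(torch.tensor(bond_idx), num_classes=4).float()
N = 3  # number of atoms
edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
print(edge_index.shape, edge_attr.shape)  # torch.Size([2, 4]) torch.Size([4, 4])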
predictions in res: + all_res.append(metric(targets, predictions)) + return np.mean(all_res) + +def evaluate_per_target_average(results, metric=r2, verbose=True): + """ + Calculate metric for training, validation and test data, averaged over all replicates. + """ + # Initialization + reps = results.keys() + metric_tr = np.empty(len(reps)) + metric_va = np.empty(len(reps)) + metric_te = np.empty(len(reps)) + # Go through training repetitions + for r, rep in enumerate(results.keys()): + # Load the predictions + train = results[rep]['train'] + val = results[rep]['valid'] + test = results[rep]['test'] + + # Calculate Statistics + metric_tr[r] = _per_target_mean(train, metric) + metric_va[r] = _per_target_mean(val, metric) + metric_te[r] = _per_target_mean(test, metric) + + if verbose: print(' - %s - Training: %7.3f - Validation: %7.3f - Test: %7.3f'%(rep, metric_tr[r], metric_va[r], metric_te[r])) + # Mean and corresponding standard deviations + summary_tr = (np.mean(metric_tr), np.std(metric_tr)) + summary_va = (np.mean(metric_va), np.std(metric_va)) + summary_te = (np.mean(metric_te), np.std(metric_te)) + if verbose: + print('---') + print(' Training: %7.3f +/- %7.3f'%summary_tr) + print(' Validation: %7.3f +/- %7.3f'%summary_va) + print(' Test: %7.3f +/- %7.3f'%summary_te) + return summary_tr, summary_va, summary_te diff --git a/atom3d/util/results.py b/atom3d/util/results.py index 6998c84..c5ee237 100644 --- a/atom3d/util/results.py +++ b/atom3d/util/results.py @@ -2,6 +2,7 @@ import pickle import torch import numpy as np +import pandas as pd import scipy as sp import scipy.stats as stats @@ -44,27 +45,55 @@ def __init__(self, name, reps=[1,2,3]): def get_prediction(self, prediction_fn): """ - Reads targets and prediction. - - TODO: Implement this! - + Reads targets and prediction """ - targets, predict = None, None + pr_data = torch.load(prediction_fn) + targets = np.array( pr_data['targets'] ) + predict = np.array( pr_data['predictions'] ) return targets, predict def get_all_predictions(self): results = {} for r, rep in enumerate(self.reps): prediction_fn = self.name + '-rep'+str(int(rep))+'.best' - # Load the Cormorant predictions targets_tr, predict_tr = self.get_prediction(prediction_fn+'.train.pt') - targets_va, predict_va = self.get_prediction(prediction_fn+'.valid.pt') + targets_va, predict_va = self.get_prediction(prediction_fn+'.val.pt') targets_te, predict_te = self.get_prediction(prediction_fn+'.test.pt') targets = {'train':targets_tr, 'valid':targets_va, 'test':targets_te} predict = {'train':predict_tr, 'valid':predict_va, 'test':predict_te} results['rep'+str(int(rep))] = {'targets':targets, 'predict':predict} return results - + + def get_predictions_by_target(self, prediction_fn): + results_df = pd.DataFrame(torch.load(prediction_fn)) + per_target = [] + for key, val in results_df.groupby(['target']): + # Ignore target with 2 decoys only since the correlations are + # not really meaningful. + if val.shape[0] < 3: + continue + true = val['true'].astype(float).to_numpy() + pred = val['pred'].astype(float).to_numpy() + per_target.append((true, pred)) + global_true = results_df['true'].astype(float).to_numpy() + global_pred = results_df['pred'].astype(float).to_numpy() + return global_true, global_pred, per_target + + def get_target_specific_predictions(self): + """For use with PSR/RSR. 
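# [Review note] Hedged sketch (not part of the patch) of the nested layout the
# new evaluate_per_target_average consumes, matching what
# get_target_specific_predictions produces: results[rep][split] is a list of
# (targets, predictions) pairs, one per protein target. Synthetic numbers.
import numpy as np
import atom3d.util.metrics as met

pair = (np.array([1.0, 2.0, 3.0]), np.array([1.1, 1.9, 3.2]))
results = {
    'rep0': {'train': [pair, pair], 'valid': [pair], 'test': [pair]},
    'rep1': {'train': [pair, pair], 'valid': [pair], 'test': [pair]},
}
met.evaluate_per_target_average(results, metric=met.spearman, verbose=True)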
Here `target` refers to the protein target, not the prediction target.""" + results = {'global':{}, 'per_target':{}} + for r, rep in enumerate(self.reps): + prediction_fn = self.name + '-rep'+str(int(rep))+'.best' + targets_tr, predict_tr, per_target_tr = self.get_predictions_by_target(prediction_fn+'.train.pt') + targets_va, predict_va, per_target_va = self.get_predictions_by_target(prediction_fn+'.val.pt') + targets_te, predict_te, per_target_te = self.get_predictions_by_target(prediction_fn+'.test.pt') + targets = {'train':targets_tr, 'valid':targets_va, 'test':targets_te} + predict = {'train':predict_tr, 'valid':predict_va, 'test':predict_te} + per_target = {'train': per_target_tr, 'valid':per_target_va, 'test':per_target_te} + results['global']['rep'+str(int(rep))] = {'targets':targets, 'predict':predict} + results['per_target']['rep'+str(int(rep))] = per_target + return results + class ResultsENN(): diff --git a/atom3d/util/transforms.py b/atom3d/util/transforms.py index 5a9dfbc..2100382 100644 --- a/atom3d/util/transforms.py +++ b/atom3d/util/transforms.py @@ -46,7 +46,7 @@ def prot_graph_transform(item, atom_keys=['atoms'], label_key='scores'): return item -def mol_graph_transform(item, atom_key='atoms', label_key='scores', use_bonds=False): +def mol_graph_transform(item, atom_key='atoms', label_key='scores', use_bonds=False, onehot_edges=False): """Transform for converting dataframes to Pytorch Geometric graphs, to be applied when defining a :mod:`Dataset `. Operates on Dataset items, assumes that the item contains all keys specified in ``keys`` and ``labels`` arguments. @@ -67,7 +67,7 @@ def mol_graph_transform(item, atom_key='atoms', label_key='scores', use_bonds=Fa bonds = item['bonds'] else: bonds = None - node_feats, edge_index, edge_feats, pos = gr.mol_df_to_graph(item[atom_key], bonds=bonds) + node_feats, edge_index, edge_feats, pos = gr.mol_df_to_graph(item[atom_key], bonds=bonds, onehot_edges=onehot_edges) item[atom_key] = Data(node_feats, edge_index, edge_feats, y=item[label_key], pos=pos) return item diff --git a/examples/lba/gnn/data.py b/examples/lba/gnn/data.py index cea1191..21e6bfa 100644 --- a/examples/lba/gnn/data.py +++ b/examples/lba/gnn/data.py @@ -1,5 +1,8 @@ import numpy as np import os +import sys +from tqdm import tqdm +import torch from atom3d.util.transforms import prot_graph_transform, mol_graph_transform from atom3d.datasets import LMDBDataset from torch_geometric.data import Data, Dataset, DataLoader @@ -17,7 +20,7 @@ def __call__(self, item): else: item = prot_graph_transform(item, atom_keys=['atoms_protein', 'atoms_pocket'], label_key='scores') # transform ligand into PTG graph - item = mol_graph_transform(item, 'atoms_ligand', 'scores', use_bonds=True) + item = mol_graph_transform(item, 'atoms_ligand', 'scores', use_bonds=True, onehot_edges=False) node_feats, edges, edge_feats, node_pos = gr.combine_graphs(item['atoms_pocket'], item['atoms_ligand']) combined_graph = Data(node_feats, edges, edge_feats, y=item['scores']['neglog_aff'], pos=node_pos) return combined_graph @@ -25,10 +28,24 @@ def __call__(self, item): if __name__=="__main__": - dataset = LMDBDataset('/scratch/users/aderry/lmdb/atom3d/lba_lmdb/splits/split-by-sequence-identity-30/data/train', transform=GNNTransformLBA()) - dataloader = DataLoader(dataset, batch_size=1, shuffle=False) - # for item in dataset[0]: - # print(item, type(dataset[0][item])) - for item in dataloader: - print(item) - break \ No newline at end of file + seqid = sys.argv[1] + save_dir = 
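# [Review note] Sketch of wiring the new onehot_edges flag through a Dataset
# transform, as GNNTransformLBA above does for the ligand. The path is a
# placeholder; assumes LMDB items carry 'atoms' (+ optional 'bonds') and 'scores'.
from atom3d.datasets import LMDBDataset
from atom3d.util.transforms import mol_graph_transform

def transform(item):
    # use_bonds=True keeps chemical-bond edges; onehot_edges=True yields
    # 4-dim one-hot bond features instead of scalar bond orders.
    return mol_graph_transform(item, 'atoms', 'scores',
                               use_bonds=True, onehot_edges=True)

# dataset = LMDBDataset('/path/to/lmdb', transform=transform)  # placeholder path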
'/scratch/users/aderry/atom3d/lba_' + str(seqid) + data_dir = f'/scratch/users/raphtown/atom3d_mirror/lmdb/LBA/splits/split-by-sequence-identity-{seqid}/data' + os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True) + train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=GNNTransformLBA()) + val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformLBA()) + test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=GNNTransformLBA()) + + print('processing train dataset...') + for i, item in enumerate(tqdm(train_dataset)): + torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt')) + + print('processing validation dataset...') + for i, item in enumerate(tqdm(val_dataset)): + torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt')) + + print('processing test dataset...') + for i, item in enumerate(tqdm(test_dataset)): + torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt')) diff --git a/examples/lba/gnn/evaluate.py b/examples/lba/gnn/evaluate.py new file mode 100644 index 0000000..48b2983 --- /dev/null +++ b/examples/lba/gnn/evaluate.py @@ -0,0 +1,24 @@ +import sys +import numpy as np +import torch +import atom3d.util.results as res +import atom3d.util.metrics as met + +seqid = sys.argv[1] + +# Define the training run +name = f'logs/lba_test_{seqid}/lba' +print(name) + +# Load training results +rloader = res.ResultsGNN(name, reps=[0,1,2]) +results = rloader.get_all_predictions() + +# Calculate and print results +summary = met.evaluate_average(results, metric = met.rmse, verbose = False) +print('Test RMSE: %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results, metric = met.spearman, verbose = False) +print('Test Spearman: %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results, metric = met.pearson, verbose = False) +print('Test Pearson: %6.3f \pm %6.3f'%summary[2]) + diff --git a/examples/lba/gnn/train.py b/examples/lba/gnn/train.py index aef89b6..daf5cbf 100644 --- a/examples/lba/gnn/train.py +++ b/examples/lba/gnn/train.py @@ -12,7 +12,7 @@ from torch_geometric.data import DataLoader from model import GNN_LBA from data import GNNTransformLBA -from atom3d.datasets import LMDBDataset +from atom3d.datasets import LMDBDataset, PTGDataset from scipy.stats import spearmanr def train_loop(model, loader, optimizer, device): @@ -66,13 +66,19 @@ def plot_corr(y_true, y_pred, plot_dir): def save_weights(model, weight_dir): torch.save(model.state_dict(), weight_dir) -def train(args, device, log_dir, seed=None, test_mode=False): +def train(args, device, log_dir, rep=None, test_mode=False): # logger = logging.getLogger('lba') # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO) - train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformLBA()) - val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformLBA()) - test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformLBA()) + if args.precomputed: + train_dataset = PTGDataset(os.path.join(args.data_dir, 'train')) + val_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + test_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + else: + transform=GNNTransformLBA() + train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform) + val_dataset = 
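# [Review note] The three copy-pasted precompute loops above could be folded
# into one helper; a sketch under the same assumptions (each transformed item
# is a serializable PyG graph, saved in the layout PTGDataset expects).
import os
import torch
from tqdm import tqdm

def precompute_split(dataset, out_dir):
    # Serialize every transformed item as data_{i}.pt for PTGDataset.
    os.makedirs(out_dir, exist_ok=True)
    for i, item in enumerate(tqdm(dataset)):
        torch.save(item, os.path.join(out_dir, f'data_{i}.pt'))

# for split in ('train', 'val', 'test'):
#     precompute_split(LMDBDataset(os.path.join(data_dir, split),
#                                  transform=GNNTransformLBA()),
#                      os.path.join(save_dir, split))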
LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform) + test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform) train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4) @@ -97,7 +103,12 @@ def train(args, device, log_dir, seed=None, test_mode=False): train_loss = train_loop(model, train_loader, optimizer, device) val_loss, r_p, r_s, y_true, y_pred = test(model, val_loader, device) if val_loss < best_val_loss: - save_weights(model, os.path.join(log_dir, f'best_weights.pt')) + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': train_loss, + }, os.path.join(log_dir, f'best_weights_rep{rep}.pt')) # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}.png')) best_val_loss = val_loss best_rp = r_p @@ -108,13 +119,18 @@ def train(args, device, log_dir, seed=None, test_mode=False): # logger.info('{:03d}\t{:.7f}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(epoch, train_loss, val_loss, r_p, r_s)) if test_mode: - test_file = os.path.join(log_dir, f'test_results.txt') - model.load_state_dict(torch.load(os.path.join(log_dir, f'best_weights.pt'))) - rmse, pearson, spearman, y_true, y_pred = test(model, test_loader, device) - # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png')) - print('Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(rmse, pearson, spearman)) - with open(test_file, 'a+') as out: - out.write('{}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(seed, rmse, pearson, spearman)) + train_file = os.path.join(log_dir, f'lba-rep{rep}.best.train.pt') + val_file = os.path.join(log_dir, f'lba-rep{rep}.best.val.pt') + test_file = os.path.join(log_dir, f'lba-rep{rep}.best.test.pt') + cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt')) + model.load_state_dict(cpt['model_state_dict']) + _, _, _, y_true_train, y_pred_train = test(model, train_loader, device) + torch.save({'targets':y_true_train, 'predictions':y_pred_train}, train_file) + _, _, _, y_true_val, y_pred_val = test(model, val_loader, device) + torch.save({'targets':y_true_val, 'predictions':y_pred_val}, val_file) + rmse, pearson, spearman, y_true_test, y_pred_test = test(model, test_loader, device) + print(f'\tTest RMSE {rmse}, Test Pearson {pearson}, Test Spearman {spearman}') + torch.save({'targets':y_true_test, 'predictions':y_pred_test}, test_file) @@ -127,9 +143,11 @@ def train(args, device, log_dir, seed=None, test_mode=False): parser.add_argument('--mode', type=str, default='train') parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--hidden_dim', type=int, default=64) - parser.add_argument('--num_epochs', type=int, default=50) + parser.add_argument('--num_epochs', type=int, default=100) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--log_dir', type=str, default=None) + parser.add_argument('--seqid', type=int, default=30) + parser.add_argument('--precomputed', action='store_true') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -147,11 +165,11 @@ def train(args, device, log_dir, seed=None, test_mode=False): train(args, device, log_dir) elif args.mode == 'test': - for seed in np.random.randint(0, 1000, size=3): + for rep, seed in enumerate(np.random.randint(0, 1000, size=3)): print('seed:', seed) - log_dir = os.path.join('logs', 
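# [Review note] Sketch of the checkpoint round trip introduced above: saving a
# dict (epoch/model/optimizer/loss) instead of a bare state_dict means loading
# must index into 'model_state_dict', as the test branch below now does.
import torch

def save_checkpoint(path, model, optimizer, epoch, loss):
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss}, path)

def load_checkpoint(path, model, optimizer=None):
    cpt = torch.load(path)
    model.load_state_dict(cpt['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(cpt['optimizer_state_dict'])
    return cpt['epoch'], cpt['loss']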
f'test_{seed}') + log_dir = os.path.join('logs', f'lba_test_{args.seqid}') if not os.path.exists(log_dir): os.makedirs(log_dir) np.random.seed(seed) torch.manual_seed(seed) - train(args, device, log_dir, seed, test_mode=True) \ No newline at end of file + train(args, device, log_dir, rep, test_mode=True) diff --git a/examples/lep/gnn/data.py b/examples/lep/gnn/data.py index 4faa0ca..5caa24b 100644 --- a/examples/lep/gnn/data.py +++ b/examples/lep/gnn/data.py @@ -1,8 +1,10 @@ import numpy as np import os +import torch +from tqdm import tqdm from atom3d.util.transforms import prot_graph_transform, PairedGraphTransform from atom3d.datasets import LMDBDataset -from torch_geometric.data import Data, Dataset, DataLoader +from torch_geometric.data import Data, Batch, DataLoader import atom3d.util.graph as gr @@ -16,16 +18,39 @@ def __call__(self, item): item = prot_graph_transform(item, atom_keys=self.atom_keys, label_key=self.label_key) return item + +class CollaterLEP(object): + """To be used with pre-computed graphs and atom3d.datasets.PTGDataset""" + def __init__(self): + pass + def __call__(self, data_list): + batch_1 = Batch.from_data_list([d[0] for d in data_list]) + batch_2 = Batch.from_data_list([d[1] for d in data_list]) + return batch_1, batch_2 if __name__=="__main__": - dataset = LMDBDataset(os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/LEP/splits/split-by-protein/data', 'train'), transform=PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label')) - dataloader = DataLoader(dataset, batch_size=4, shuffle=False) - for active, inactive in dataloader: - print(active) - print(inactive) - break - # for item in dataloader: - # print(item) - # break \ No newline at end of file + save_dir = '/scratch/users/aderry/atom3d/lep' + data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/LEP/splits/split-by-protein/data' + os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True) + transform = PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label') + train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=transform) + val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=transform) + test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=transform) + + # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4) + # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4) + # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4) + # for item in dataset[0]: + # print(item, type(dataset[0][item])) + for i, item in enumerate(tqdm(train_dataset)): + torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt')) + + for i, item in enumerate(tqdm(val_dataset)): + torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt')) + + for i, item in enumerate(tqdm(test_dataset)): + torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt')) \ No newline at end of file diff --git a/examples/lep/gnn/evaluate.py b/examples/lep/gnn/evaluate.py new file mode 100644 index 0000000..76caf82 --- /dev/null +++ b/examples/lep/gnn/evaluate.py @@ -0,0 +1,19 @@ +import numpy as np +import torch +import atom3d.util.results as res +import atom3d.util.metrics as met + +# Define the training run +name = 'logs/lep_test/lep' +print(name) + +# Load training results +rloader = res.ResultsGNN(name, reps=[0,1,2]) +results = rloader.get_all_predictions() + +# Calculate and 
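# [Review note] Sketch of how CollaterLEP above pairs a plain (non-PyG)
# torch DataLoader with PTGDataset items, where each item is an
# (active, inactive) graph tuple. Batch.from_data_list builds one
# disjoint-union batch per side. Mirrors the class in the patch.
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

class CollaterLEP:
    def __call__(self, data_list):
        return (Batch.from_data_list([d[0] for d in data_list]),
                Batch.from_data_list([d[1] for d in data_list]))

# loader = DataLoader(PTGDataset(path), batch_size=4, collate_fn=CollaterLEP())
# for active, inactive in loader:
#     ...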
print results +summary = met.evaluate_average(results, metric = met.auroc, verbose = False) +print('Test AUROC: %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results, metric = met.auprc, verbose = False) +print('Test AUPRC: %6.3f \pm %6.3f'%summary[2]) + diff --git a/examples/lep/gnn/train.py b/examples/lep/gnn/train.py index 15e1cca..96ddc6c 100644 --- a/examples/lep/gnn/train.py +++ b/examples/lep/gnn/train.py @@ -11,10 +11,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.data import DataLoader +from torch_geometric.data import DataLoader as PTGDataLoader +from torch.utils.data import DataLoader from model import GNN_LEP, MLP_LEP +from data import CollaterLEP from atom3d.util.transforms import PairedGraphTransform -from atom3d.datasets import LMDBDataset +from atom3d.datasets import LMDBDataset, PTGDataset from scipy.stats import spearmanr from sklearn.metrics import roc_auc_score, average_precision_score @@ -75,7 +77,7 @@ def test(gcn_model, ff_model, loader, criterion, device): # total += active.num_graphs losses.append(loss.item()) y_true.extend(labels.tolist()) - y_pred.extend(output.tolist()) + y_pred.extend(torch.sigmoid(output).tolist()) if it % print_frequency == 0: print(f'iter {it}, loss {np.mean(losses)}') @@ -100,13 +102,21 @@ def train(args, device, log_dir, rep=None, test_mode=False): # logger = logging.getLogger('lba') # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO) transform = PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label') - train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform) - val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform) - test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform) + if args.precomputed: + train_dataset = PTGDataset(os.path.join(args.data_dir, 'train')) + val_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + test_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4, collate_fn=CollaterLEP()) + val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4, collate_fn=CollaterLEP()) + test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, collate_fn=CollaterLEP()) + else: + train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform) + val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform) + test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform) - train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4) - val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4) - test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4) + train_loader = PTGDataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4) + val_loader = PTGDataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4) + test_loader = PTGDataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4) for active, inactive in train_loader: num_features = active.num_features @@ -138,21 +148,26 @@ def train(args, device, log_dir, rep=None, test_mode=False): 'ff_state_dict': ff_model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': train_loss, - }, os.path.join(log_dir, f'best_weights.pt')) + }, 
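# [Review note] The change above maps raw logits through a sigmoid before the
# ranking metrics. AUROC/AUPRC are rank-based, so a monotonic sigmoid does not
# change their values; the benefit is that saved predictions live in [0, 1].
import torch
from sklearn.metrics import roc_auc_score

logits = torch.tensor([-1.2, 0.3, 2.1, -0.5])
labels = [0, 1, 1, 0]
print(roc_auc_score(labels, logits.tolist()))                 # same AUROC...
print(roc_auc_score(labels, torch.sigmoid(logits).tolist()))  # ...either way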
os.path.join(log_dir, f'best_weights_rep{rep}.pt')) best_val_auroc = auroc elapsed = (time.time() - start) print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed)) print(f'\tTrain loss {train_loss}, Val loss {val_loss}, Val AUROC {auroc}, Val auprc {auprc}') if test_mode: - test_file = os.path.join(log_dir, f'lep_rep{rep}.csv') - cpt = torch.load(os.path.join(log_dir, f'best_weights.pt')) + train_file = os.path.join(log_dir, f'lep-rep{rep}.best.train.pt') + val_file = os.path.join(log_dir, f'lep-rep{rep}.best.val.pt') + test_file = os.path.join(log_dir, f'lep-rep{rep}.best.test.pt') + cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt')) gcn_model.load_state_dict(cpt['gcn_state_dict']) ff_model.load_state_dict(cpt['ff_state_dict']) - test_loss, auroc, auprc, y_true, y_pred = test(gcn_model, ff_model, test_loader, criterion, device) + _, _, _, y_true_train, y_pred_train = test(gcn_model, ff_model, train_loader, criterion, device) + torch.save({'targets':y_true_train, 'predictions':y_pred_train}, train_file) + _, _, _, y_true_val, y_pred_val = test(gcn_model, ff_model, val_loader, criterion, device) + torch.save({'targets':y_true_val, 'predictions':y_pred_val}, val_file) + test_loss, auroc, auprc, y_true_test, y_pred_test = test(gcn_model, ff_model, test_loader, criterion, device) print(f'\tTest loss {test_loss}, Test AUROC {auroc}, Test auprc {auprc}') - res_df = pd.DataFrame(y_true, y_pred, columns=['true', 'pred']) - res_df.to_csv(test_file, index=False) + torch.save({'targets':y_true_test, 'predictions':y_pred_test}, test_file) return test_loss, auroc, auprc @@ -167,15 +182,15 @@ def train(args, device, log_dir, rep=None, test_mode=False): parser.add_argument('--mode', type=str, default='train') parser.add_argument('--batch_size', type=int, default=4) parser.add_argument('--hidden_dim', type=int, default=64) - parser.add_argument('--num_epochs', type=int, default=50) + parser.add_argument('--num_epochs', type=int, default=20) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--log_dir', type=str, default=None) + parser.add_argument('--precomputed', action='store_true') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') log_dir = args.log_dir - if args.mode == 'train': if log_dir is None: now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") @@ -189,9 +204,9 @@ def train(args, device, log_dir, rep=None, test_mode=False): elif args.mode == 'test': for rep, seed in enumerate(np.random.randint(0, 1000, size=3)): print('seed:', seed) - log_dir = os.path.join('logs', f'test_rep{rep}') + log_dir = os.path.join('logs', f'lep_test') if not os.path.exists(log_dir): os.makedirs(log_dir) np.random.seed(seed) torch.manual_seed(seed) - train(args, device, log_dir, seed, test_mode=True) \ No newline at end of file + train(args, device, log_dir, rep, test_mode=True) diff --git a/examples/msp/gnn/data.py b/examples/msp/gnn/data.py index dd6cf08..eaf132e 100644 --- a/examples/msp/gnn/data.py +++ b/examples/msp/gnn/data.py @@ -1,6 +1,7 @@ import numpy as np import os import torch +from tqdm import tqdm from atom3d.util.transforms import prot_graph_transform, PairedGraphTransform from atom3d.datasets import LMDBDataset from torch_geometric.data import Data, Dataset, Batch @@ -68,18 +69,30 @@ def __call__(self, batch): if __name__=="__main__": from tqdm import tqdm - # dataset = LMDBDataset(os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data', 
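# [Review note] Sketch of the prediction-file contract the test branches above
# now share with results.ResultsGNN.get_prediction: a dict with 'targets' and
# 'predictions', saved via torch.save. Synthetic values; placeholder path.
import numpy as np
import torch

path = '/tmp/lep-rep0.best.test.pt'  # placeholder
torch.save({'targets': [0, 1, 1], 'predictions': [0.2, 0.8, 0.7]}, path)

pr_data = torch.load(path)
targets = np.array(pr_data['targets'])
predict = np.array(pr_data['predictions'])
print(targets, predict)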
'train')) - # dataloader = DataLoader(dataset, batch_size=3, shuffle=False, num_workers=4) - # for i, item in tqdm(enumerate(dataloader)): - # if i < 578: - # continue - # print(item) - + + save_dir = '/scratch/users/aderry/atom3d/msp' + data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data' + os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True) + transform = GNNTransformMSP() + train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=transform) + val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=transform) + test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=transform) + + # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4, collate_fn=CollaterMSP(batch_size=1)) + # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4, collate_fn=CollaterMSP(batch_size=1)) + # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, collate_fn=CollaterMSP(batch_size=1)) + # for item in dataset[0]: + # print(item, type(dataset[0][item])) + + # for i, item in enumerate(tqdm(train_dataset)): + # torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt')) + + # print('processing validation dataset...') + # for i, item in enumerate(tqdm(val_dataset)): + # torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt')) - dataset = LMDBDataset(os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data', 'train'), transform=GNNTransformMSP()) - dataloader = DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=CollaterMSP(batch_size=3), num_workers=4) - for original, mutated in tqdm(dataloader): - if mutated.mut_idx.max() > mutated.batch.shape[0]: - print(mutated.batch.shape) - print(mutated.mut_idx) - break \ No newline at end of file + print('processing test dataset...') + for i, item in enumerate(tqdm(test_dataset)): + torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt')) \ No newline at end of file diff --git a/examples/msp/gnn/evaluate.py b/examples/msp/gnn/evaluate.py new file mode 100644 index 0000000..86446ca --- /dev/null +++ b/examples/msp/gnn/evaluate.py @@ -0,0 +1,19 @@ +import numpy as np +import torch +import atom3d.util.results as res +import atom3d.util.metrics as met + +# Define the training run +name = 'logs/msp_test/msp' +print(name) + +# Load training results +rloader = res.ResultsGNN(name, reps=[0,1,2]) +results = rloader.get_all_predictions() + +# Calculate and print results +summary = met.evaluate_average(results, metric = met.auroc, verbose = False) +print('Test AUROC: %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results, metric = met.auprc, verbose = False) +print('Test AUPRC: %6.3f \pm %6.3f'%summary[2]) + diff --git a/examples/msp/gnn/train.py b/examples/msp/gnn/train.py index 75f47f7..8609d6f 100644 --- a/examples/msp/gnn/train.py +++ b/examples/msp/gnn/train.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader from model import GNN_MSP, MLP_MSP from data import GNNTransformMSP, CollaterMSP -from atom3d.datasets import LMDBDataset +from atom3d.datasets import LMDBDataset, PTGDataset from scipy.stats import spearmanr from sklearn.metrics import roc_auc_score, average_precision_score @@ -71,16 +71,16 @@ def test(gcn_model, ff_model, loader, criterion, device): # total += original.num_graphs losses.append(loss.item()) 
y_true.extend(original.y.tolist()) - y_pred.extend(output.tolist()) - if it % print_frequency == 0: - print(f'iter {it}, loss {np.mean(losses)}') + y_pred.extend(torch.sigmoid(output).tolist()) + # if it % print_frequency == 0: + # print(f'iter {it}, loss {np.mean(losses)}') y_true = np.array(y_true) y_pred = np.array(y_pred) auroc = roc_auc_score(y_true, y_pred) auprc = average_precision_score(y_true, y_pred) - return np.mean(losses), auroc, auprc + return np.mean(losses), auroc, auprc, y_true, y_pred def plot_corr(y_true, y_pred, plot_dir): plt.clf() @@ -92,18 +92,23 @@ def plot_corr(y_true, y_pred, plot_dir): def save_weights(model, weight_dir): torch.save(model.state_dict(), weight_dir) -def train(args, device, log_dir, seed=None, test_mode=False): +def train(args, device, log_dir, rep=None, test_mode=False): # logger = logging.getLogger('lba') # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO) transform = GNNTransformMSP() - train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform) - val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform) - test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform) + if args.precomputed: + train_dataset = PTGDataset(os.path.join(args.data_dir, 'train')) + val_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + test_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + else: + train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform) + val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform) + test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform) train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4, collate_fn=CollaterMSP(batch_size=args.batch_size)) val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4, collate_fn=CollaterMSP(batch_size=args.batch_size)) test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, collate_fn=CollaterMSP(batch_size=args.batch_size)) - + for original, mutated in train_loader: num_features = original.num_features break @@ -126,7 +131,7 @@ def train(args, device, log_dir, seed=None, test_mode=False): start = time.time() train_loss = train_loop(epoch, gcn_model, ff_model, train_loader, criterion, optimizer, device) print('validating...') - val_loss, auroc, auprc = test(gcn_model, ff_model, val_loader, criterion, device) + val_loss, auroc, auprc, _, _ = test(gcn_model, ff_model, val_loader, criterion, device) if auroc > best_val_auroc: torch.save({ 'epoch': epoch, @@ -134,22 +139,26 @@ def train(args, device, log_dir, seed=None, test_mode=False): 'ff_state_dict': ff_model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': train_loss, - }, os.path.join(log_dir, f'best_weights.pt')) + }, os.path.join(log_dir, f'best_weights_rep{rep}.pt')) best_val_auroc = auroc elapsed = (time.time() - start) print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed)) print(f'\tTrain loss {train_loss}, Val loss {val_loss}, Val AUROC {auroc}, Val auprc {auprc}') if test_mode: - test_file = os.path.join(log_dir, f'test_results.txt') - cpt = torch.load(os.path.join(log_dir, f'best_weights.pt')) + train_file = os.path.join(log_dir, f'msp-rep{rep}.best.train.pt') + val_file = os.path.join(log_dir, f'msp-rep{rep}.best.val.pt') + test_file = os.path.join(log_dir, f'msp-rep{rep}.best.test.pt') + cpt = 
torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt')) gcn_model.load_state_dict(cpt['gcn_state_dict']) ff_model.load_state_dict(cpt['ff_state_dict']) - test_loss, auroc, auprc = test(gcn_model, ff_model, test_loader, criterion, device) + _, _, _, y_true_train, y_pred_train = test(gcn_model, ff_model, train_loader, criterion, device) + torch.save({'targets':y_true_train, 'predictions':y_pred_train}, train_file) + _, _, _, y_true_val, y_pred_val = test(gcn_model, ff_model, val_loader, criterion, device) + torch.save({'targets':y_true_val, 'predictions':y_pred_val}, val_file) + test_loss, auroc, auprc, y_true_test, y_pred_test = test(gcn_model, ff_model, test_loader, criterion, device) print(f'\tTest loss {test_loss}, Test AUROC {auroc}, Test auprc {auprc}') - with open(test_file, 'w') as f: - f.write(f'test_loss\tAUROC\n') - f.write(f'{test_loss}\t{auroc}\n') + torch.save({'targets':y_true_test, 'predictions':y_pred_test}, test_file) return test_loss, auroc, auprc @@ -166,6 +175,7 @@ def train(args, device, log_dir, seed=None, test_mode=False): parser.add_argument('--num_epochs', type=int, default=50) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--log_dir', type=str, default=None) + parser.add_argument('--precomputed', action='store_true') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -183,11 +193,11 @@ def train(args, device, log_dir, seed=None, test_mode=False): train(args, device, log_dir) elif args.mode == 'test': - for seed in np.random.randint(0, 1000, size=3): + for rep, seed in enumerate(np.random.randint(0, 1000, size=3)): print('seed:', seed) - log_dir = os.path.join('logs', f'test_{seed}') + log_dir = os.path.join('logs', f'msp_test') if not os.path.exists(log_dir): os.makedirs(log_dir) np.random.seed(seed) torch.manual_seed(seed) - train(args, device, log_dir, seed, test_mode=True) \ No newline at end of file + train(args, device, log_dir, rep, test_mode=True) \ No newline at end of file diff --git a/examples/psr/gnn/data.py b/examples/psr/gnn/data.py index a6c122d..1f0e013 100644 --- a/examples/psr/gnn/data.py +++ b/examples/psr/gnn/data.py @@ -1,6 +1,7 @@ import numpy as np import os import torch +from tqdm import tqdm from atom3d.util.transforms import prot_graph_transform from atom3d.datasets import LMDBDataset from torch_geometric.data import Data, Dataset, DataLoader @@ -13,17 +14,37 @@ def __call__(self, item): item = prot_graph_transform(item, ['atoms'], 'scores') graph = item['atoms'] graph.y = torch.FloatTensor([graph.y['gdt_ts']]) - graph.target = item['id'][0] - graph.decoy = item['id'][1] + split = item['id'].split("'") + graph.target = split[1] + graph.decoy = split[3] return graph if __name__=="__main__": - dataset = LMDBDataset('/scratch/users/raphtown/atom3d_mirror/lmdb/PSR/splits/split-by-year/data/train', transform=GNNTransformPSR()) - dataloader = DataLoader(dataset, batch_size=4, shuffle=False) + save_dir = '/scratch/users/aderry/atom3d/psr' + data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/PSR/splits/split-by-year/data' + os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True) + train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=GNNTransformPSR()) + val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformPSR()) + test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), 
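# [Review note] Sketch of the id parsing introduced in the PSR/RSR transforms
# above. It assumes item['id'] is the *string* repr of a (target, decoy)
# tuple, e.g. "('T0843', 'decoy_12.pdb')", so split("'") yields the two names
# at indices 1 and 3. ast.literal_eval would be a sturdier alternative if the
# tuple-repr format is guaranteed. The id value here is hypothetical.
import ast

item_id = "('T0843', 'decoy_12.pdb')"   # hypothetical id string
split = item_id.split("'")
print(split[1], split[3])                # T0843 decoy_12.pdb
print(ast.literal_eval(item_id))         # ('T0843', 'decoy_12.pdb')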
transform=GNNTransformPSR()) + + # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4) + # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4) + # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4) # for item in dataset[0]: # print(item, type(dataset[0][item])) - for item in dataloader: - print(item) - break \ No newline at end of file + + # print('processing train dataset...') + # for i, item in enumerate(tqdm(train_dataset)): + # torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt')) + + # print('processing validation dataset...') + # for i, item in enumerate(tqdm(val_dataset)): + # torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt')) + + print('processing test dataset...') + for i, item in enumerate(tqdm(test_dataset)): + torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt')) \ No newline at end of file diff --git a/examples/psr/gnn/evaluate.py b/examples/psr/gnn/evaluate.py new file mode 100644 index 0000000..cbdd42d --- /dev/null +++ b/examples/psr/gnn/evaluate.py @@ -0,0 +1,28 @@ +import numpy as np +import torch +import atom3d.util.results as res +import atom3d.util.metrics as met + +# Define the training run +name = 'logs/psr_test/psr' +print(name) + +# Load training results +rloader = res.ResultsGNN(name, reps=[0,1,2]) +results = rloader.get_target_specific_predictions() + +# Calculate and print results +summary = met.evaluate_per_target_average(results['per_target'], metric = met.spearman, verbose = False) +print('Test Spearman (per-target): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_per_target_average(results['per_target'], metric = met.pearson, verbose = False) +print('Test Pearson (per-target): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_per_target_average(results['per_target'], metric = met.kendall, verbose = False) +print('Test Kendall (per-target): %6.3f \pm %6.3f'%summary[2]) + +summary = met.evaluate_average(results['global'], metric = met.spearman, verbose = False) +print('Test Spearman (global): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results['global'], metric = met.pearson, verbose = False) +print('Test Pearson (global): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results['global'], metric = met.kendall, verbose = False) +print('Test Kendall (global): %6.3f \pm %6.3f'%summary[2]) + diff --git a/examples/psr/gnn/model.py b/examples/psr/gnn/model.py index 171b103..dba85a3 100644 --- a/examples/psr/gnn/model.py +++ b/examples/psr/gnn/model.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.nn import GCNConv, global_mean_pool +from torch_geometric.nn import GCNConv, global_add_pool class GNN_PSR(torch.nn.Module): def __init__(self, num_features, hidden_dim): @@ -35,7 +35,7 @@ def forward(self, x, edge_index, edge_weight, batch): x = F.relu(x) x = self.conv5(x, edge_index, edge_weight) x = self.bn5(x) - x = global_mean_pool(x, batch) + x = global_add_pool(x, batch) x = F.relu(x) x = F.relu(self.fc1(x)) x = F.dropout(x, p=0.25, training=self.training) diff --git a/examples/psr/gnn/train.py b/examples/psr/gnn/train.py index b35c9fd..97a13be 100644 --- a/examples/psr/gnn/train.py +++ b/examples/psr/gnn/train.py @@ -3,7 +3,6 @@ import os import time import datetime -import wandb import matplotlib.pyplot as plt import numpy as np @@ -14,7 +13,7 @@ from torch_geometric.data import DataLoader from model import GNN_PSR from data import GNNTransformPSR -from atom3d.datasets import 
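# [Review note] The pooling swap in psr/gnn/model.py above changes the readout
# from a size-invariant mean to a size-sensitive sum; a tiny sketch of the
# difference on a single 5-node graph.
import torch
from torch_geometric.nn import global_add_pool, global_mean_pool

x = torch.ones(5, 2)                       # 5 nodes, 2 features, all ones
batch = torch.zeros(5, dtype=torch.long)   # all nodes in one graph
print(global_mean_pool(x, batch))          # [[1., 1.]] regardless of node count
print(global_add_pool(x, batch))           # [[5., 5.]] scales with graph size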
LMDBDataset +from atom3d.datasets import LMDBDataset, PTGDataset import atom3d.datasets.psr.util as psr_util def compute_correlations(results): @@ -71,12 +70,11 @@ def train_loop(model, loader, optimizer, device): loss_all += loss.item() * data.num_graphs total += data.num_graphs optimizer.step() - wandb.log({'train_loss': loss}) return np.sqrt(loss_all / total) @torch.no_grad() -def test(model, loader, device, log=True): +def test(model, loader, device): model.eval() losses = [] @@ -109,8 +107,6 @@ def test(model, loader, device, log=True): ) res = compute_correlations(test_df) - if log: - wandb.log({'val_loss': np.mean(losses), 'pearson': res['all_pearson'], 'kendall': res['all_kendall'], 'spearman': res['all_spearman']}) return np.mean(losses), res, test_df @@ -124,13 +120,19 @@ def plot_corr(y_true, y_pred, plot_dir): def save_weights(model, weight_dir): torch.save(model.state_dict(), weight_dir) -def train(args, device, log_dir, seed=None, test_mode=False): +def train(args, device, log_dir, rep=None, test_mode=False): # logger = logging.getLogger('lba') # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO) - train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformPSR()) - val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformPSR()) - test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformPSR()) + if args.precomputed: + train_dataset = PTGDataset(os.path.join(args.data_dir, 'train')) + val_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + test_dataset = PTGDataset(os.path.join(args.data_dir, 'val')) + + else: + train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformPSR()) + val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformPSR()) + test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformPSR()) train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4) val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4) @@ -149,23 +151,21 @@ def train(args, device, log_dir, seed=None, test_mode=False): optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', - factor=0.7, patience=3, - min_lr=0.00001) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', + # factor=0.7, patience=3, + # min_lr=0.00001) for epoch in range(1, args.num_epochs+1): start = time.time() train_loss = train_loop(model, train_loader, optimizer, device) - print('validating...') - val_loss, corrs, test_df = test(model, val_loader, device) - scheduler.step(val_loss) + val_loss, corrs, results_df = test(model, val_loader, device) if corrs['all_spearman'] > best_rs: torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': train_loss, - }, os.path.join(log_dir, f'best_weights.pt')) + }, os.path.join(log_dir, f'best_weights_rep{rep}.pt')) best_rs = corrs['all_spearman'] elapsed = (time.time() - start) print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed)) @@ -173,13 +173,19 @@ def train(args, device, log_dir, seed=None, test_mode=False): train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman'])) if test_mode: - test_file = os.path.join(log_dir, f'psr_rep{rep}.csv') - model.load_state_dict(torch.load(os.path.join(log_dir, 
f'best_weights.pt'))) - val_loss, corrs, results_df = test(model, test_loader, device, log=False) - # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png')) - print('\tTest RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'.format( - train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman'])) - pd.to_csv(results_df, test_file, index=False) + cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt')) + model.load_state_dict(cpt['model_state_dict']) + train_file = os.path.join(log_dir, f'psr-rep{rep}.best.train.pt') + val_file = os.path.join(log_dir, f'psr-rep{rep}.best.val.pt') + test_file = os.path.join(log_dir, f'psr-rep{rep}.best.test.pt') + cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt')) + model.load_state_dict(cpt['model_state_dict']) + _, corrs, results_train = test(model, train_loader, device) + torch.save(results_train.to_dict('list'), train_file) + _, corrs, results_val = test(model, val_loader, device) + torch.save(results_val.to_dict('list'), val_file) + _, corrs, results_test = test(model, test_loader, device) + torch.save(results_test.to_dict('list'), test_file) @@ -187,19 +193,16 @@ def train(args, device, log_dir, seed=None, test_mode=False): parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str) parser.add_argument('--mode', type=str, default='train') - parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--batch_size', type=int, default=40) parser.add_argument('--hidden_dim', type=int, default=64) parser.add_argument('--num_epochs', type=int, default=20) parser.add_argument('--learning_rate', type=float, default=1e-4) parser.add_argument('--log_dir', type=str, default=None) + parser.add_argument('--precomputed', action='store_true') args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') log_dir = args.log_dir - - wandb.init(project="atom3d", name='PSR', config=vars(args) - ) - if args.mode == 'train': if log_dir is None: @@ -212,11 +215,11 @@ def train(args, device, log_dir, seed=None, test_mode=False): train(args, device, log_dir) elif args.mode == 'test': - for seed in np.random.randint(0, 1000, size=3): + for rep, seed in enumerate(np.random.randint(0, 1000, size=3)): print('seed:', seed) - log_dir = os.path.join('logs', f'test_{seed}') + log_dir = os.path.join('logs', f'psr_test') if not os.path.exists(log_dir): os.makedirs(log_dir) np.random.seed(seed) torch.manual_seed(seed) - train(args, device, log_dir, seed, test_mode=True) + train(args, device, log_dir, rep, test_mode=True) diff --git a/examples/rsr/gnn/data.py b/examples/rsr/gnn/data.py index 69a3423..d494256 100644 --- a/examples/rsr/gnn/data.py +++ b/examples/rsr/gnn/data.py @@ -1,6 +1,7 @@ import numpy as np import os import torch +from tqdm import tqdm from atom3d.util.transforms import prot_graph_transform from atom3d.datasets import LMDBDataset from torch_geometric.data import Data, Dataset, DataLoader @@ -13,17 +14,36 @@ def __call__(self, item): item = prot_graph_transform(item, ['atoms'], 'scores') graph = item['atoms'] graph.y = torch.FloatTensor([graph.y['rms']]) - graph.target = item['id'][0] - graph.decoy = item['id'][1] + split = item['id'].split("'") + graph.target = split[1] + graph.decoy = split[3] return graph if __name__=="__main__": - dataset = LMDBDataset('/scratch/users/raphtown/atom3d_mirror/lmdb/RSR/splits/candidates-split-by-time/data/train', transform=GNNTransformRSR()) - dataloader = DataLoader(dataset, 
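# [Review note] Sketch of the DataFrame round trip used above: the PSR/RSR
# test branches save results_df.to_dict('list') via torch.save, and
# get_predictions_by_target rebuilds the frame with pd.DataFrame(torch.load(...)).
import pandas as pd
import torch

df = pd.DataFrame({'target': ['T1', 'T1', 'T2'],
                   'decoy': ['d1', 'd2', 'd1'],
                   'true': [0.5, 0.7, 0.3],
                   'pred': [0.6, 0.65, 0.4]})
torch.save(df.to_dict('list'), '/tmp/psr-rep0.best.test.pt')  # placeholder path
restored = pd.DataFrame(torch.load('/tmp/psr-rep0.best.test.pt'))
print(restored.equals(df))  # True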
batch_size=4, shuffle=False) + save_dir = '/scratch/users/aderry/atom3d/rsr' + data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/RSR/splits/candidates-split-by-time/data' + os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True) + os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True) + train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=GNNTransformRSR()) + val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformRSR()) + test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=GNNTransformRSR()) + + # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4) + # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4) + # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4) # for item in dataset[0]: # print(item, type(dataset[0][item])) - for item in dataloader: - print(item) - break + print('processing train dataset...') + for i, item in enumerate(tqdm(train_dataset)): + torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt')) + + print('processing validation dataset...') + for i, item in enumerate(tqdm(val_dataset)): + torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt')) + + print('processing test dataset...') + for i, item in enumerate(tqdm(test_dataset)): + torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt')) diff --git a/examples/rsr/gnn/evaluate.py b/examples/rsr/gnn/evaluate.py new file mode 100644 index 0000000..5bb8806 --- /dev/null +++ b/examples/rsr/gnn/evaluate.py @@ -0,0 +1,28 @@ +import numpy as np +import torch +import atom3d.util.results as res +import atom3d.util.metrics as met + +# Define the training run +name = 'logs/rsr_test/rsr' +print(name) + +# Load training results +rloader = res.ResultsGNN(name, reps=[0,1,2]) +results = rloader.get_target_specific_predictions() + +# Calculate and print results +summary = met.evaluate_per_target_average(results['per_target'], metric = met.spearman, verbose = False) +print('Test Spearman (per-target): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_per_target_average(results['per_target'], metric = met.pearson, verbose = False) +print('Test Pearson (per-target): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_per_target_average(results['per_target'], metric = met.kendall, verbose = False) +print('Test Kendall (per-target): %6.3f \pm %6.3f'%summary[2]) + +summary = met.evaluate_average(results['global'], metric = met.spearman, verbose = False) +print('Test Spearman (global): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results['global'], metric = met.pearson, verbose = False) +print('Test Pearson (global): %6.3f \pm %6.3f'%summary[2]) +summary = met.evaluate_average(results['global'], metric = met.kendall, verbose = False) +print('Test Kendall (global): %6.3f \pm %6.3f'%summary[2]) + diff --git a/examples/rsr/gnn/train.py b/examples/rsr/gnn/train.py index c6e1fa9..6576dbf 100644 --- a/examples/rsr/gnn/train.py +++ b/examples/rsr/gnn/train.py @@ -3,7 +3,6 @@ import os import time import datetime -import wandb import matplotlib.pyplot as plt import numpy as np @@ -14,7 +13,7 @@ from torch_geometric.data import DataLoader from model import GNN_RSR from data import GNNTransformRSR -from atom3d.datasets import LMDBDataset +from atom3d.datasets import LMDBDataset, PTGDataset from scipy.stats import spearmanr @@ -73,12 +72,11 @@ def train_loop(model, loader, optimizer, device): loss_all += loss.item() * 
diff --git a/examples/rsr/gnn/train.py b/examples/rsr/gnn/train.py
index c6e1fa9..6576dbf 100644
--- a/examples/rsr/gnn/train.py
+++ b/examples/rsr/gnn/train.py
@@ -3,7 +3,6 @@
 import os
 import time
 import datetime
-import wandb

 import matplotlib.pyplot as plt
 import numpy as np
@@ -14,7 +13,7 @@
 from torch_geometric.data import DataLoader
 from model import GNN_RSR
 from data import GNNTransformRSR
-from atom3d.datasets import LMDBDataset
+from atom3d.datasets import LMDBDataset, PTGDataset
 from scipy.stats import spearmanr

@@ -73,12 +72,11 @@ def train_loop(model, loader, optimizer, device):
         loss_all += loss.item() * data.num_graphs
         total += data.num_graphs
         optimizer.step()
-        wandb.log({'train_loss': loss})
     return np.sqrt(loss_all / total)

 @torch.no_grad()
-def test(model, loader, device, log=True):
+def test(model, loader, device):
     model.eval()

     losses = []
@@ -104,8 +102,6 @@
     )
     res = compute_correlations(results_df)
-    if log:
-        wandb.log({'val_loss': np.mean(losses), 'pearson': res['all_pearson'], 'kendall': res['all_kendall'], 'spearman': res['all_spearman']})

     return np.sqrt(np.mean(losses)), res, results_df

@@ -122,10 +118,16 @@ def save_weights(model, weight_dir):
 def train(args, device, log_dir, rep=None, test_mode=False):
     # logger = logging.getLogger('lba')
     # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)
-
-    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformRSR())
-    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformRSR())
-    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformRSR())
+
+    if args.precomputed:
+        train_dataset = PTGDataset(os.path.join(args.data_dir, 'train'))
+        val_dataset = PTGDataset(os.path.join(args.data_dir, 'val'))
+        test_dataset = PTGDataset(os.path.join(args.data_dir, 'test'))
+    else:
+        train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformRSR())
+        val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformRSR())
+        test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformRSR())

     train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
     val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
@@ -154,7 +156,7 @@
                 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': train_loss,
-            }, os.path.join(log_dir, f'best_weights.pt'))
+            }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
             best_rs = corrs['all_spearman']
         elapsed = (time.time() - start)
         print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
@@ -162,13 +164,19 @@
             train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))

     if test_mode:
-        test_file = os.path.join(log_dir, f'rsr_rep{rep}.csv')
-        model.load_state_dict(torch.load(os.path.join(log_dir, f'best_weights.pt')))
-        val_loss, corrs, results_df = test(model, test_loader, device, log=False)
-        # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
-        print('\tTest RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'.format(
-            train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))
-        pd.to_csv(results_df, test_file, index=False)
+        # Load the best checkpoint once, then score all three splits.
+        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
+        model.load_state_dict(cpt['model_state_dict'])
+        train_file = os.path.join(log_dir, f'rsr-rep{rep}.best.train.pt')
+        val_file = os.path.join(log_dir, f'rsr-rep{rep}.best.val.pt')
+        test_file = os.path.join(log_dir, f'rsr-rep{rep}.best.test.pt')
+        _, corrs, results_train = test(model, train_loader, device)
+        torch.save(results_train.to_dict('list'), train_file)
+        _, corrs, results_val = test(model, val_loader, device)
+        torch.save(results_val.to_dict('list'), val_file)
+        _, corrs, results_test = test(model, test_loader, device)
+        torch.save(results_test.to_dict('list'), test_file)
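Since the checkpoint is now a dict rather than a bare state_dict, anything consuming best_weights_rep{N}.pt has to index into it, as the test-mode block above does. A sketch of resuming from one of these files, assuming model and optimizer have already been rebuilt with the same hyperparameters used at training time:

import torch

# Hypothetical resume snippet: 'model' and 'optimizer' must be constructed
# exactly as in train() before loading the saved states.
cpt = torch.load('logs/rsr_test/best_weights_rep0.pt', map_location='cpu')
model.load_state_dict(cpt['model_state_dict'])
optimizer.load_state_dict(cpt['optimizer_state_dict'])
start_epoch = cpt['epoch'] + 1  # saved fields: epoch, model/optimizer state, loss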
@@ -181,13 +189,12 @@ def train(args, device, log_dir, rep=None, test_mode=False):
     parser.add_argument('--num_epochs', type=int, default=20)
     parser.add_argument('--learning_rate', type=float, default=1e-4)
     parser.add_argument('--log_dir', type=str, default=None)
+    parser.add_argument('--precomputed', action='store_true')
     args = parser.parse_args()

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     log_dir = args.log_dir

-    wandb.init(project="atom3d", name='RSR', config=vars(args)
-    )

     if args.mode == 'train':
         if log_dir is None:
@@ -202,9 +209,9 @@ def train(args, device, log_dir, rep=None, test_mode=False):
     elif args.mode == 'test':
         for rep, seed in enumerate(np.random.randint(0, 1000, size=3)):
             print('seed:', seed)
-            log_dir = os.path.join('logs', f'test_rep{rep}')
+            log_dir = os.path.join('logs', 'rsr_test')
             if not os.path.exists(log_dir):
                 os.makedirs(log_dir)
             np.random.seed(seed)
             torch.manual_seed(seed)
-            train(args, device, log_dir, seed, test_mode=True)
+            train(args, device, log_dir, rep, test_mode=True)
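All three test reps now share logs/rsr_test, distinguished only by the rep index baked into the file names; that is the layout the new evaluate.py hands to res.ResultsGNN(name, reps=[0,1,2]). A quick sketch of inspecting those files directly, without going through ResultsGNN:

import torch

# The paths mirror the test-mode save locations in train.py.
for rep in range(3):
    results = torch.load(f'logs/rsr_test/rsr-rep{rep}.best.test.pt')
    print(rep, sorted(results.keys()))  # dict-of-lists from results_df.to_dict('list')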
diff --git a/examples/smp/gnn/data.py b/examples/smp/gnn/data.py
index e69de29..af950c3 100644
--- a/examples/smp/gnn/data.py
+++ b/examples/smp/gnn/data.py
@@ -0,0 +1,57 @@
+import numpy as np
+import os
+import torch
+from tqdm import tqdm
+from atom3d.util.transforms import mol_graph_transform
+from atom3d.datasets import LMDBDataset
+from torch_geometric.data import Data, Dataset, DataLoader
+
+
+class GNNTransformSMP(object):
+    def __init__(self, label_name):
+        self.label_name = label_name
+
+    def _lookup_label(self, item, name):
+        # Build the label-name -> column-index table lazily on first use;
+        # the order matches the 'labels' list written by prepare_lmdb.py.
+        if 'label_mapping' not in self.__dict__:
+            label_mapping = [
+                'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve',
+                'u0', 'u298', 'h298', 'g298', 'cv',
+                'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'cv_atom',
+            ]
+            self.label_mapping = {k: v for v, k in enumerate(label_mapping)}
+        return item['labels'][self.label_mapping[name]]
+
+    def __call__(self, item):
+        item = mol_graph_transform(item, 'atoms', 'labels', use_bonds=True, onehot_edges=True)
+        graph = item['atoms']
+        x2 = torch.tensor(item['atom_feats'], dtype=torch.float).t().contiguous()
+        graph.x = torch.cat([graph.x.to(torch.float), x2], dim=-1)
+        graph.y = self._lookup_label(item, self.label_name)
+        graph.id = item['id']
+        return graph
+
+
+if __name__=="__main__":
+    save_dir = '/scratch/users/aderry/atom3d/smp'
+    data_dir = '/scratch/users/aderry/lmdb/atom3d/small_molecule_properties'
+    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
+    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
+    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
+    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=GNNTransformSMP(label_name='mu'))
+    # val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformSMP())
+    # test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=GNNTransformSMP())
+
+    # Smoke test: iterate once; uncomment the save loops to precompute graphs.
+    for i, item in enumerate(tqdm(train_dataset)):
+        break
+        # torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))
+
+    # for i, item in enumerate(tqdm(val_dataset)):
+    #     torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt'))
+
+    # for i, item in enumerate(tqdm(test_dataset)):
+    #     torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt'))
\ No newline at end of file
diff --git a/examples/smp/gnn/evaluate.py b/examples/smp/gnn/evaluate.py
new file mode 100644
index 0000000..86b6357
--- /dev/null
+++ b/examples/smp/gnn/evaluate.py
@@ -0,0 +1,19 @@
+import numpy as np
+import atom3d.util.results as res
+import atom3d.util.metrics as met
+
+labels = np.loadtxt('labels.txt', dtype=str)
+# Conversion factors from Hartree to eV (1 Ha = 27.2114 eV); zpve to meV.
+conversion = {'A': 1.0, 'B': 1.0, 'C': 1.0, 'mu': 1.0, 'alpha': 1.0,
+              'homo': 27.2114, 'lumo': 27.2114, 'gap': 27.2114, 'r2': 1.0, 'zpve': 27211.4,
+              'u0': 27.2114, 'u298': 27.2114, 'h298': 27.2114, 'g298': 27.2114, 'cv': 1.0,
+              'u0_atom': 27.2114, 'u298_atom': 27.2114, 'h298_atom': 27.2114, 'g298_atom': 27.2114, 'cv_atom': 1.0}
+
+for label in labels:
+    name = f'logs/smp_test_{label}/smp'
+    cf = conversion[label]
+    rloader = res.ResultsGNN(name, reps=[0, 1, 2])
+    results = rloader.get_all_predictions()
+    summary = met.evaluate_average(results, metric=met.mae, verbose=False)
+    summary = [(cf * s[0], cf * s[1]) for s in summary]
+    print('%9s: %6.3f \\pm %6.3f' % (label, *summary[2]))
+
diff --git a/examples/smp/gnn/model.py b/examples/smp/gnn/model.py
index e69de29..3f927fc 100644
--- a/examples/smp/gnn/model.py
+++ b/examples/smp/gnn/model.py
@@ -0,0 +1,31 @@
+import torch
+from torch.nn import Sequential, Linear, ReLU, GRU
+import torch.nn.functional as F
+from torch_geometric.nn import NNConv, Set2Set
+
+class GNN_SMP(torch.nn.Module):
+    def __init__(self, num_features, dim):
+        super(GNN_SMP, self).__init__()
+        self.lin0 = torch.nn.Linear(num_features, dim)
+
+        # Edge network: maps 4-dim one-hot bond features to a dim x dim kernel.
+        nn = Sequential(Linear(4, 128), ReLU(), Linear(128, dim * dim))
+        self.conv = NNConv(dim, dim, nn, aggr='mean')
+        self.gru = GRU(dim, dim)
+
+        self.set2set = Set2Set(dim, processing_steps=3)
+        self.lin1 = torch.nn.Linear(2 * dim, dim)
+        self.lin2 = torch.nn.Linear(dim, 1)
+
+    def forward(self, data):
+        out = F.relu(self.lin0(data.x))
+        h = out.unsqueeze(0)
+
+        # Three rounds of message passing, carrying a GRU hidden state across rounds.
+        for i in range(3):
+            m = F.relu(self.conv(out, data.edge_index, data.edge_attr))
+            out, h = self.gru(m.unsqueeze(0), h)
+            out = out.squeeze(0)
+
+        out = self.set2set(out, data.batch)
+        out = F.relu(self.lin1(out))
+        out = self.lin2(out)
+        return out.view(-1)
\ No newline at end of file
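A minimal smoke test for GNN_SMP on a dummy graph, useful for checking tensor shapes before touching the LMDB data. The feature width (13) and the batch construction are illustrative assumptions; the only hard constraint from model.py is the 4-dim edge_attr expected by the NNConv edge network:

import torch
from torch_geometric.data import Data, Batch
from model import GNN_SMP  # assumes examples/smp/gnn is the working directory

num_features = 13  # e.g. 5 one-hot element types + 8 extra atom features
model = GNN_SMP(num_features, dim=64)

x = torch.randn(6, num_features)                         # 6 atoms
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
edge_attr = torch.randn(4, 4)                            # 4-dim edge features, per the edge network
batch = Batch.from_data_list([Data(x=x, edge_index=edge_index, edge_attr=edge_attr)])
print(model(batch).shape)  # torch.Size([1]): one scalar prediction per graph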
diff --git a/examples/smp/gnn/train.py b/examples/smp/gnn/train.py
index e69de29..1feba8a 100644
--- a/examples/smp/gnn/train.py
+++ b/examples/smp/gnn/train.py
@@ -0,0 +1,150 @@
+import argparse
+import logging
+import os
+import time
+import datetime
+
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from torch_geometric.data import DataLoader
+from model import GNN_SMP
+from data import GNNTransformSMP
+from atom3d.datasets import LMDBDataset
+
+def train_loop(model, loader, optimizer, device):
+    model.train()
+
+    loss_all = 0
+    total = 0
+    for data in loader:
+        data = data.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.mse_loss(output, data.y)
+        loss.backward()
+        loss_all += loss.item() * data.num_graphs
+        total += data.num_graphs
+        optimizer.step()
+    return loss_all / total
+
+@torch.no_grad()
+def test(model, loader, device):
+    model.eval()
+    loss_all = 0
+    total = 0
+    y_true = []
+    y_pred = []
+    for data in loader:
+        data = data.to(device)
+        output = model(data)
+        loss = F.l1_loss(output, data.y)  # MAE
+        loss_all += loss.item() * data.num_graphs
+        total += data.num_graphs
+        y_true.extend([x.item() for x in data.y])
+        y_pred.extend(output.tolist())
+    return loss_all / total, y_true, y_pred
+
+
+def save_weights(model, weight_dir):
+    torch.save(model.state_dict(), weight_dir)
+
+def train(args, device, log_dir, rep=None, test_mode=False):
+    # logger = logging.getLogger('lba')
+    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)
+
+    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformSMP(args.target_name))
+    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformSMP(args.target_name))
+    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformSMP(args.target_name))
+
+    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
+    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4)
+
+    # Infer the input feature width from the first batch.
+    for data in train_loader:
+        num_features = data.num_features
+        break
+
+    model = GNN_SMP(num_features, dim=args.hidden_dim).to(device)
+
+    best_val_loss = float('inf')
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
+                                                           factor=0.7, patience=3,
+                                                           min_lr=0.00001)
+
+    for epoch in range(1, args.num_epochs+1):
+        start = time.time()
+        train_loss = train_loop(model, train_loader, optimizer, device)
+        print('validating...')
+        val_loss, _, _ = test(model, val_loader, device)
+        scheduler.step(val_loss)
+        if val_loss < best_val_loss:
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss': train_loss,
+            }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
+            best_val_loss = val_loss
+        elapsed = (time.time() - start)
+        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
+        print('\tTrain Loss: {:.7f}, Val MAE: {:.7f}'.format(train_loss, val_loss))
+
+    if test_mode:
+        train_file = os.path.join(log_dir, f'smp-rep{rep}.best.train.pt')
+        val_file = os.path.join(log_dir, f'smp-rep{rep}.best.val.pt')
+        test_file = os.path.join(log_dir, f'smp-rep{rep}.best.test.pt')
+        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
+        model.load_state_dict(cpt['model_state_dict'])
+        _, y_true_train, y_pred_train = test(model, train_loader, device)
+        torch.save({'targets': y_true_train, 'predictions': y_pred_train}, train_file)
+        _, y_true_val, y_pred_val = test(model, val_loader, device)
+        torch.save({'targets': y_true_val, 'predictions': y_pred_val}, val_file)
+        mae, y_true_test, y_pred_test = test(model, test_loader, device)
+        print(f'\tTest MAE {mae}')
+        torch.save({'targets': y_true_test, 'predictions': y_pred_test}, test_file)
+
+
+if __name__=="__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data_dir', type=str)
+    parser.add_argument('--target_name', type=str)
+    parser.add_argument('--mode', type=str, default='train')
+    parser.add_argument('--batch_size', type=int, default=128)
+    parser.add_argument('--hidden_dim', type=int, default=64)
+    parser.add_argument('--num_epochs', type=int, default=300)
+    parser.add_argument('--learning_rate', type=float, default=1e-3)
+    parser.add_argument('--log_dir', type=str, default=None)
+    args = parser.parse_args()
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    log_dir = args.log_dir
+
+    if args.mode == 'train':
+        if log_dir is None:
+            now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+            log_dir = os.path.join('logs', now)
+        else:
+            log_dir = os.path.join('logs', log_dir)
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir)
+        train(args, device, log_dir)
+
+    elif args.mode == 'test':
+        for rep, seed in enumerate(np.random.randint(0, 1000, size=3)):
+            print('seed:', seed)
+            log_dir = os.path.join('logs', f'smp_test_{args.target_name}')
+            if not os.path.exists(log_dir):
+                os.makedirs(log_dir)
+            np.random.seed(seed)
+            torch.manual_seed(seed)
+            train(args, device, log_dir, rep, test_mode=True)
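The SMP prediction files are plain dicts with 'targets' and 'predictions' lists, so a quick sanity check does not need atom3d.util.results at all. A sketch, assuming test mode has been run for the mu target from examples/smp/gnn (e.g. python train.py --data_dir <lmdb_root> --target_name mu --mode test):

import numpy as np
import torch

# Check one rep's saved test predictions against the printed Test MAE.
d = torch.load('logs/smp_test_mu/smp-rep0.best.test.pt')
targets = np.array(d['targets'])
preds = np.array(d['predictions'])
print('test MAE:', np.mean(np.abs(targets - preds)))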