# helper.py
# Forked from wjm41/soapgp.
# (GitHub page-scrape residue — navigation text and the line-number gutter —
#  removed; the original file was 136 lines / 4.94 KB.)
import argparse
import math
import random
from collections import defaultdict

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Scaffolds import MurckoScaffold
def split_by_lengths(seq, num_list):
    """
    Slice ``seq`` into consecutive chunks whose sizes are given by ``num_list``.

    :param seq: a list/array to be sliced
    :param num_list: positive integers, the size of each successive chunk
    :return: list of slices of ``seq``, one per entry of ``num_list``
    """
    chunks = []
    start = 0
    for size in num_list:
        end = start + size
        chunks.append(seq[start:end])
        start = end
    return chunks
def return_borders(index, dat_len, mpi_size):
    """
    Compute the index range owned by one MPI rank when partitioning data evenly.

    :param index: rank of the MPI process
    :param dat_len: total length of the data array being partitioned
    :param mpi_size: total number of MPI processes
    :return: (low, high) indices delimiting this rank's share of the data
    """
    # mpi_size + 1 evenly spaced boundaries; consecutive pairs delimit ranks.
    boundaries = np.linspace(0, dat_len, mpi_size + 1).astype('int')
    return boundaries[index], boundaries[index + 1]
def generate_scaffold(mol, include_chirality=False):
    """
    Compute the Bemis-Murcko scaffold SMILES for a molecule.

    Implementation adapted from https://github.com/chemprop/chemprop.

    :param mol: a SMILES string or an RDKit molecule
    :param include_chirality: whether to encode chirality in the scaffold.
        Defaults to False (matching the chemprop reference implementation);
        previously the parameter had no default, so ``scaffold_to_smiles``'s
        one-argument call ``generate_scaffold(mol)`` raised a TypeError.
    :return: the scaffold as a SMILES string
    """
    # Accept either a SMILES string or an already-parsed RDKit Mol.
    if isinstance(mol, str):
        mol = Chem.MolFromSmiles(mol)
    return MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)
def scaffold_to_smiles(mols, use_indices):
    """
    Group molecules by their Bemis-Murcko scaffold.

    Implementation adapted from https://github.com/chemprop/chemprop.
    NOTE: requires ``from collections import defaultdict`` at module level —
    the original file used ``defaultdict`` without importing it.

    :param mols: a list of SMILES strings or RDKit molecules
    :param use_indices: if True, map each scaffold to the indices of its
        molecules in ``mols`` rather than to the molecules themselves
        (necessary when ``mols`` contains duplicates)
    :return: dict mapping each unique scaffold to the set of molecules
        (or indices) exhibiting it
    """
    scaffolds = defaultdict(set)
    for i, mol in enumerate(mols):
        # Pass include_chirality explicitly: generate_scaffold was declared
        # without a default for it, so a one-argument call raises TypeError.
        scaffold = generate_scaffold(mol, include_chirality=False)
        if use_indices:
            scaffolds[scaffold].add(i)
        else:
            scaffolds[scaffold].add(mol)
    return scaffolds
def scaffold_split(data,
sizes = (0.8, 0.2),
balanced = True,
seed = 0):
"""
Split a dataset by scaffold so that no molecules sharing a scaffold are in the same split.
Implementation copied from https://github.com/chemprop/chemprop.
:param data: List of smiles strings
:param sizes: A length-2 tuple with the proportions of data in the
train and test sets.
:param balanced: Try to balance sizes of scaffolds in each set, rather than just putting smallest in test set.
:param seed: Seed for shuffling when doing balanced splitting.
:return: A tuple containing the train, validation, and test splits of the data.
"""
assert sum(sizes) == 1
# Split
train_size, test_size = sizes[0] * len(data), sizes[1] * len(data)
train, test = [], []
train_scaffold_count, test_scaffold_count = 0, 0
# Map from scaffold to index in the data
scaffold_to_indices = scaffold_to_smiles(data, use_indices=True)
if balanced: # Put stuff that's bigger than half the val/test size into train, rest just order randomly
index_sets = list(scaffold_to_indices.values())
big_index_sets = []
small_index_sets = []
for index_set in index_sets:
if len(index_set) > test_size / 2:
big_index_sets.append(index_set)
else:
small_index_sets.append(index_set)
random.seed(seed)
random.shuffle(big_index_sets)
random.shuffle(small_index_sets)
index_sets = big_index_sets + small_index_sets
else: # Sort from largest to smallest scaffold sets
index_sets = sorted(list(scaffold_to_indices.values()),
key=lambda index_set: len(index_set),
reverse=True)
for index_set in index_sets:
if len(train) + len(index_set) <= train_size:
train += index_set
train_scaffold_count += 1
else:
test += index_set
test_scaffold_count += 1
#print(f'Total scaffolds = {len(scaffold_to_indices):,} | '
# f'train scaffolds = {train_scaffold_count:,} | '
# f'test scaffolds = {test_scaffold_count:,}')
# Map from indices to data
#train = [data[i] for i in train]
#test = [data[i] for i in test]
#print(train)
#print(test)
return train, test