-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_dataset.py
executable file
·118 lines (100 loc) · 3.52 KB
/
prepare_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import gzip
import logging
from functools import partial
from multiprocessing import Pool
import pandas as pd
from tqdm.auto import tqdm
from rdkit import Chem
from moses.metrics import mol_passes_filters, compute_scaffold
# Configure root logging once at import time so all pipeline stages emit
# INFO-level progress messages to stderr.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("prepare dataset")
def get_parser():
    """Build the command-line parser for the dataset-preparation script.

    Returns:
        argparse.ArgumentParser configured with the output path, random
        seed, ZINC archive path, worker count, and output-format flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output', '-o', type=str,
                        default='dataset_v1.csv',
                        help='Path for constructed dataset')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random state')
    parser.add_argument('--zinc', type=str,
                        default='../data/11_p0.smi.gz',
                        help='path to .smi.gz file with ZINC smiles')
    parser.add_argument('--n_jobs', type=int, default=1,
                        help='number of processes to use')
    parser.add_argument('--keep_ids', action='store_true', default=False,
                        help='Keep ZINC ids in the final csv file')
    parser.add_argument('--isomeric', action='store_true', default=False,
                        help='Save isomeric SMILES (non-isomeric by default)')
    return parser
def process_molecule(mol_row, isomeric):
    """Filter one raw ZINC record and canonicalize its SMILES.

    Args:
        mol_row: raw bytes line of the form b"<smiles> <zinc_id>".
        isomeric: whether stereochemistry is kept in the output SMILES.

    Returns:
        A (zinc_id, canonical_smiles) tuple, or None when the molecule
        fails the MOSES medicinal-chemistry filters.
    """
    smiles, zinc_id = mol_row.decode('utf-8').split()
    if not mol_passes_filters(smiles):
        return None
    canonical = Chem.MolToSmiles(
        Chem.MolFromSmiles(smiles), isomericSmiles=isomeric
    )
    return zinc_id, canonical
def unzip_dataset(path):
    """Read every raw line from a gzip-compressed SMILES file.

    Args:
        path: path to a .smi.gz archive.

    Returns:
        List of bytes objects, one per line, trailing newlines kept.
    """
    logger.info("Unzipping dataset")
    with gzip.open(path) as archive:
        return archive.readlines()
def filter_lines(lines, n_jobs, isomeric):
    """Filter raw ZINC lines in parallel and attach Murcko scaffolds.

    Args:
        lines: raw bytes lines ("<smiles> <id>") from unzip_dataset.
        n_jobs: number of worker processes for the multiprocessing Pool.
        isomeric: forwarded to process_molecule (keep stereochemistry).

    Returns:
        pandas DataFrame with columns ID, SMILES, scaffold, deduplicated
        on both ID and SMILES.
    """
    logger.info('Filtering SMILES')
    with Pool(n_jobs) as pool:
        process_molecule_p = partial(process_molecule, isomeric=isomeric)
        # imap_unordered yields results in nondeterministic order; the
        # sorts below restore a deterministic dataset.
        dataset = [
            x for x in tqdm(
                pool.imap_unordered(process_molecule_p, lines),
                total=len(lines),
                miniters=1000
            )
            if x is not None
        ]
        dataset = pd.DataFrame(dataset, columns=['ID', 'SMILES'])
        # Sort on (ID, SMILES) first so that, for duplicate IDs, which
        # row survives drop_duplicates('ID') is deterministic; then
        # re-sort on ID alone before deduplicating SMILES.
        dataset = dataset.sort_values(by=['ID', 'SMILES'])
        dataset = dataset.drop_duplicates('ID')
        dataset = dataset.sort_values(by='ID')
        dataset = dataset.drop_duplicates('SMILES')
        # Scaffold computation reuses the same pool (still inside `with`).
        dataset['scaffold'] = pool.map(
            compute_scaffold, dataset['SMILES'].values
        )
    return dataset
def split_dataset(dataset, seed):
    """Assign train / test / test_scaffolds splits to the dataset.

    Every 10th scaffold (ordered by descending frequency, ties broken
    alphabetically) is held out as 'test_scaffolds'; 10% of the
    remaining rows are sampled as the random 'test' split.

    Args:
        dataset: DataFrame with SMILES and scaffold columns. Mutated in
            place: gains a SPLIT column and loses the scaffold column.
        seed: random_state for the test-split sampling.

    Returns:
        The same DataFrame with SPLIT added and scaffold removed.
    """
    logger.info('Splitting the dataset')
    # Series.value_counts replaces the top-level pd.value_counts, which
    # is deprecated (removed in pandas 3.0); the result is identical.
    scaffolds = dataset['scaffold'].value_counts()
    # Deterministic ordering: descending count, then scaffold string.
    scaffolds = sorted(scaffolds.items(), key=lambda x: (-x[1], x[0]))
    test_scaffolds = set(x[0] for x in scaffolds[9::10])
    dataset['SPLIT'] = 'train'
    test_scaf_idx = [x in test_scaffolds for x in dataset['scaffold']]
    dataset.loc[test_scaf_idx, 'SPLIT'] = 'test_scaffolds'
    # Sample the random test split only from rows still marked 'train'.
    test_idx = dataset.loc[dataset['SPLIT'] == 'train'].sample(
        frac=0.1, random_state=seed
    ).index
    dataset.loc[test_idx, 'SPLIT'] = 'test'
    dataset.drop('scaffold', axis=1, inplace=True)
    return dataset
def main(config):
    """Run the full pipeline: unzip, filter, split, and write the CSV.

    Args:
        config: parsed argparse namespace (see get_parser).
    """
    lines = unzip_dataset(config.zinc)
    dataset = filter_lines(lines, config.n_jobs, config.isomeric)
    dataset = split_dataset(dataset, config.seed)
    if not config.keep_ids:
        # drop(columns=...) replaces the positional `axis` argument,
        # which was removed in pandas 2.0; behavior is unchanged.
        dataset.drop(columns='ID', inplace=True)
    dataset.to_csv(config.output, index=None)
if __name__ == '__main__':
    # Parse known arguments only, so a typo'd flag raises a clear error
    # instead of being silently ignored.
    arg_parser = get_parser()
    config, unknown = arg_parser.parse_known_args()
    if unknown:
        raise ValueError("Unknown argument " + unknown[0])
    main(config)